├── input
│   ├── src
│   │   ├── 0.png
│   │   ├── config.py
│   │   ├── data.py
│   │   ├── bert_emb.py
│   │   └── dataset_check.py
│   └── README.md
├── examples
│   ├── bird1.jpg
│   ├── bird2.jpg
│   ├── bird3.jpg
│   ├── bird4.jpg
│   ├── bird5.jpg
│   ├── flower1.jpg
│   ├── flower2.jpg
│   ├── flower3.jpg
│   ├── flower4.jpg
│   ├── flower5.jpg
│   ├── framework.jpg
│   ├── framework.png
│   └── captions
├── requirements.txt
├── src
│   ├── environment.yml
│   ├── util.py
│   ├── dataset.py
│   ├── engine.py
│   ├── train.py
│   ├── args.py
│   └── layers.py
├── cfg
│   ├── s2.yml
│   └── s1.yml
├── LICENSE
├── .gitignore
└── README.md
/input/src/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/input/src/0.png
--------------------------------------------------------------------------------
/examples/bird1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird1.jpg
--------------------------------------------------------------------------------
/examples/bird2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird2.jpg
--------------------------------------------------------------------------------
/examples/bird3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird3.jpg
--------------------------------------------------------------------------------
/examples/bird4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird4.jpg
--------------------------------------------------------------------------------
/examples/bird5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/bird5.jpg
--------------------------------------------------------------------------------
/examples/flower1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower1.jpg
--------------------------------------------------------------------------------
/examples/flower2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower2.jpg
--------------------------------------------------------------------------------
/examples/flower3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower3.jpg
--------------------------------------------------------------------------------
/examples/flower4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower4.jpg
--------------------------------------------------------------------------------
/examples/flower5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/flower5.jpg
--------------------------------------------------------------------------------
/examples/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/framework.jpg
--------------------------------------------------------------------------------
/examples/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sahilkhose/StackGAN-BERT/HEAD/examples/framework.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ujson
2 | numpy
3 | spacy
4 | tensorboard
5 | tensorflow
6 | tensorboardX
7 | tqdm
8 | urllib3
9 | torch
10 | torchvision
11 | transformers
12 | torchfile
--------------------------------------------------------------------------------
/src/environment.yml:
--------------------------------------------------------------------------------
1 | name: ganctober
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.6
8 | - ujson
9 | - numpy
10 | - pip
11 | - spacy=2.0.16
12 | - tensorboard
13 | - tensorflow
14 | - tensorboardX
15 | - tqdm
16 | - urllib3
17 | - pytorch=1.0.0
18 | - pip:
19 | - torch==1.0.0
20 |
--------------------------------------------------------------------------------
/cfg/s2.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: "stage2"
2 | STAGE: 2
3 | train_bs: 4 # 64
4 | train_workers: 1
5 | embedding_type: "bert"
6 | TRAIN_MAX_EPOCH: 600
7 | TRAIN_GEN_LR: 0.0002 # 2e-4
8 | TRAIN_DISC_LR: 0.0002 # 2e-4
9 | TRAIN_LR_DECAY_EPOCH: 100
10 | TRAIN_SNAPSHOT_INTERVAL: 10 # 2000
11 | STAGE1_G_path: "../output/model/stage1_netG_epoch_600.pth"
12 | NET_G_path: ""
13 | NET_D_path: ""
14 |
15 | pretrained_epoch: 600
--------------------------------------------------------------------------------
/cfg/s1.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: "stage1"
2 | STAGE: 1
3 | train_bs: 128 # 128 fits / 64 standard
4 | train_workers: 4
5 | embedding_type: "bert" # cnn-rnn / bert
6 | TRAIN_MAX_EPOCH: 600
7 | TRAIN_GEN_LR: 0.0002 # 2e-4
8 | TRAIN_DISC_LR: 0.0002 # 2e-4
9 | TRAIN_LR_DECAY_EPOCH: 50
10 | TRAIN_SNAPSHOT_INTERVAL: 2000
11 | NET_G_path: ""
12 | NET_D_path: ""
13 | # NET_G_path: "../old_outputs/output_1_0/model/netG_epoch_120.pth"
14 | # NET_D_path: "../old_outputs/output_1_0/model/netD_epoch_120.pth"
--------------------------------------------------------------------------------
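The keys in `cfg/s1.yml` and `cfg/s2.yml` map one-to-one onto the CLI flags defined in `src/args.py`; passing `--conf` merges them over the argparse defaults. A minimal sketch of reading one of these configs (assumes PyYAML is available, as `src/args.py` already requires):

```python
# Load a stage config and inspect a few keys. src/args.py performs the same
# yaml.load(..., Loader=yaml.Loader) call when --conf is passed; the key names
# match the argparse flag names (STAGE, train_bs, TRAIN_GEN_LR, ...).
import yaml

with open("cfg/s1.yml") as f:
    cfg = yaml.load(f, Loader=yaml.Loader)

print(cfg["STAGE"])           # 1
print(cfg["embedding_type"])  # "bert"
print(cfg["TRAIN_GEN_LR"])    # 0.0002
```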
/input/src/config.py:
--------------------------------------------------------------------------------
1 | """All the paths and arguments.
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | # import transformers
8 |
9 | # BERT_PATH = "../data/bert_base_uncased"
10 | ANNOTATIONS = "../data/birds/text_c10"
11 | ANNOTATION_EMB = "../data/birds/embeddings"
12 | IMAGE_DIR = "../data/CUB_200_2011/images"
13 | EMB_LEN = "../data/emb_lens.txt"
14 | SENT_LEN = "../data/sent_len.txt"
15 |
16 | DEVICE = "cuda"
17 |
18 | ANNOTATIONS_URL = "https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE"
19 | CUB_URL = "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz"
--------------------------------------------------------------------------------
/examples/captions:
--------------------------------------------------------------------------------
1 | A small yellow bird with a black crown and a short black pointed beak
2 | This small bird has a white breast, light grey head, and black wings and tail
3 | This bird is white, black, and brown in color, with a brown beak
4 | A white bird with a black crown and yellow beak
5 | This bird is completely red with black wings and pointy beak
6 |
7 |
8 | This flower has overlapping pink pointed petals surrounding a ring of short yellow filaments
9 | This flower has long thin yellow petals and a lot of yellow anthers in the center
10 | This flower is white, pink, and yellow in color, and has petals that are multi colored
11 | This flower is pink, white, and yellow in color, and has petals that are striped
12 | This flower has white petals with a yellow tip and a yellow pistil
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Sahil Khose
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/input/README.md:
--------------------------------------------------------------------------------
1 | ## :four_leaf_clover: New
2 | - Run the following to automate the dataset download:
3 | ```bash
4 | cd input/src
5 | python3 data.py
6 | ```
7 | ```
8 | ganctober
9 | │ LICENSE
10 | │ README.md
11 | │ requirements.txt
12 | │
13 | └──>cfg
14 | │
15 | └──>examples
16 | │
17 | └──>input
18 | │   │
19 | │   │
20 | │   └──>data
21 | │   │   │
22 | │   │   └──>bert_base_uncased
23 | │   │   │
24 | │   │   └──>birds
25 | │   │   │
26 | │   │   └──>CUB_200_2011
27 | │   │
28 | │   └──>src
29 | │       bert_emb.py
30 | │       config.py
31 | │       data.py
32 | │       dataset_check.py
33 | │
34 | │
35 | └──>old_outputs
36 | │
37 | └──>output
38 | │
39 | └──>src
40 |     args.py
41 |     dataset.py
42 |     engine.py
43 |     environment.yml
44 |     layers.py
45 |     train.py
46 |     util.py
47 | ```
48 | --------------------------------------------------------------------------------------------
49 |
50 | ## Dataset (Old)
51 | - Download the following and extract/move them to the mentioned directories so that your workspace matches the tree above.
52 | - Download the Caltech-UCSD Birds-200-2011 (CUB) dataset from: http://www.vision.caltech.edu/visipedia/CUB-200-2011.html and extract it into `input/data/` to create the `input/data/CUB_200_2011` directory, which contains the `images` directory with the images we need for our task.
53 | - Read the README about the dataset on the website.
54 | - Download the text descriptions from: https://drive.google.com/open?id=0B3y_msrWZaXLT1BZdVdycDY5TEE and extract them into `input/data/` to create the `input/data/birds` directory, which contains the `text_c10` directory with all the annotations needed for our task.
55 | - Download bert_base_uncased from: https://www.kaggle.com/abhishek/bert-base-uncased and extract it into `input/data/` to create `input/data/bert_base_uncased`, which is used to generate the BERT embeddings of the annotations.
--------------------------------------------------------------------------------
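A quick sanity check of the layout described in the README above, run from the repository root (a sketch; these are the paths the download steps are expected to create):

```python
# Verify that the dataset directories described above exist.
import os

expected = [
    "input/data/CUB_200_2011/images",   # CUB images
    "input/data/birds/text_c10",        # text annotations
    "input/data/bert_base_uncased",     # BERT weights (only needed if loading locally)
]
for path in expected:
    print(f"{path}: {'OK' if os.path.isdir(path) else 'MISSING'}")
```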
/.gitignore:
--------------------------------------------------------------------------------
1 | # input data and models
2 | input/data/*
3 | output/
4 | old_outputs/
5 | *.tar
6 |
7 | # data files
8 | *.csv
9 | *.pt
10 | *.npy
11 | *.txt
12 | !requirements.txt
13 | *.h5
14 | *.pkl
15 | *.pth
16 |
17 | ######################################################################
18 | ######################################################################
19 | # Byte-compiled / optimized / DLL files
20 | __pycache__/
21 | *.py[cod]
22 | *$py.class
23 |
24 | # C extensions
25 | *.so
26 |
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | share/python-wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 | MANIFEST
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .tox/
60 | .nox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | *.py,cover
68 | .hypothesis/
69 | .pytest_cache/
70 | cover/
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 | db.sqlite3-journal
81 |
82 | # Flask stuff:
83 | instance/
84 | .webassets-cache
85 |
86 | # Scrapy stuff:
87 | .scrapy
88 |
89 | # Sphinx documentation
90 | docs/_build/
91 |
92 | # PyBuilder
93 | .pybuilder/
94 | target/
95 |
96 | # Jupyter Notebook
97 | .ipynb_checkpoints
98 |
99 | # IPython
100 | profile_default/
101 | ipython_config.py
102 |
103 | # pyenv
104 | # For a library or package, you might want to ignore these files since the code is
105 | # intended to run in multiple environments; otherwise, check them in:
106 | # .python-version
107 |
108 | # pipenv
109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
112 | # install all needed dependencies.
113 | #Pipfile.lock
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 | ganctober.code-workspace
158 |
--------------------------------------------------------------------------------
/src/util.py:
--------------------------------------------------------------------------------
1 | """Utility classes and methods.
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | import args
8 |
9 | import matplotlib.pyplot as plt
10 | import os
11 | import pandas as pd
12 | import torch
13 | import torchvision.utils as vutils
14 |
15 | from json import dumps
16 | print("__"*80)
17 |
18 | #TODO fetch saved generated images during training and their corresponding annotations
19 |
20 | def save_img_results(data_img, fake, epoch, args):
21 | num = args.VIS_COUNT
22 | fake = fake[0:num]
23 | # data_img is changed to [0, 1]
24 | if data_img is not None:
25 | data_img = data_img[0:num]
26 | vutils.save_image(data_img, os.path.join(args.image_save_dir, "real_samples.png"), normalize=True)
27 | # fake data is still [-1, 1]
28 | vutils.save_image(fake.data, os.path.join(args.image_save_dir, f"fake_samples_epoch_{epoch:04}.png"), normalize=True)
29 | else:
30 | vutils.save_image(fake.data, os.path.join(args.image_save_dir, f"lr_fake_samples_epoch_{epoch:04}.png"), normalize=True)
31 |
32 | def save_model(netG, netD, epoch, args):
33 | torch.save(netG.state_dict(), os.path.join(args.model_dir, f"netG_epoch_{epoch}.pth"))
34 | torch.save(netD.state_dict(), os.path.join(args.model_dir, f"netD_epoch_{epoch}.pth")) #? implementation has saved just the last disc? decision?
35 | print("Save G/D models")
36 |
37 |
38 | def make_dir(path):
39 | if not os.path.exists(path):
40 | os.makedirs(path)
41 |
42 |
43 | def check_dataset(training_set):
44 | t, i, b = training_set[1]
45 | print("Bert emb shape: ", t.shape)
46 | print("bbox: ", b)
47 | plt.imshow(i)
48 | plt.show()
49 | print("__"*80)
50 |
51 | def check_args():
52 | """
53 | To test args.py
54 | """
55 | print("get_data_args:")
56 | data_args = args.get_all_args()
57 | print(f'Args: {dumps(vars(data_args), indent=4, sort_keys=True)}')
58 | print("__"*80)
59 |
60 | print("To fetch arguments:")
61 | print(data_args.device)
62 | print(data_args.images_dir)
63 | print("__"*80)
64 |
65 | #####################################################################################
66 | print("To fetch images, text embeddings, bert embeddings:")
67 | df_train_filenames = pd.read_pickle(data_args.train_filenames)
68 | # print(df_train_filenames[0]) # List[str] : len train -> 8855, len test -> 2933
69 | image_path = os.path.join(data_args.images_dir, df_train_filenames[0] + ".jpg")
70 | text_path = os.path.join(data_args.annotations_dir, df_train_filenames[0] + ".txt")
71 | bert_path = os.path.join(data_args.bert_annotations_dir, df_train_filenames[0], "0.pt")
72 |
73 |
74 | print("\nBird type: ")
75 | print(df_train_filenames[0].split("/")[0])
76 |
77 | print("\nAnnotations of the bird image: \n")
78 | [print(f"{idx}: {ele}") for idx, ele in enumerate(open(text_path).read().split("\n")[:-1])]
79 |
80 | print("\nShape of bert embedding of annotation no 0:")
81 | emb = torch.load(bert_path)
82 | print(emb.shape) # (1, 768)
83 |
84 | img = plt.imread(image_path)
85 | plt.imshow(img)
86 | plt.show()
87 |
88 |
89 |
90 |
91 | if __name__ == "__main__":
92 | check_args()
93 | # make_dir("../output/model")
94 |
--------------------------------------------------------------------------------
/input/src/data.py:
--------------------------------------------------------------------------------
1 | """Downloads data
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | import requests
8 | import os
9 | import tqdm
10 |
11 | class GoogleDriveDownloader(object):
12 | """
13 | Downloads a file stored on Google Drive given its URL.
14 | If the link points to another resource, the redirect chain is expanded first.
15 | `download()` returns the output path.
16 | """
17 |
18 | base_url = 'https://docs.google.com/uc?export=download'
19 | chunk_size = 32768
20 |
21 | def __init__(self, url, out_dir):
22 | super().__init__()
23 |
24 | self.out_name = url.rsplit('/', 1)[-1]
25 | self.url = self._get_redirect_url(url)
26 | self.out_dir = out_dir
27 |
28 | @staticmethod
29 | def _get_redirect_url(url):
30 | response = requests.get(url)
31 | if response.url != url and response.url is not None:
32 | redirect_url = response.url
33 | return redirect_url
34 | else:
35 | return url
36 |
37 | @staticmethod
38 | def _get_confirm_token(response):
39 | for key, value in response.cookies.items():
40 | if key.startswith('download_warning'):
41 | return value
42 | return None
43 |
44 | def _save_response_content(self, response):
45 | with open(self.fpath, 'wb') as f:
46 | bar = tqdm.tqdm(total=None)
47 | progress = 0
48 | for chunk in response.iter_content(self.chunk_size):
49 | if chunk:
50 | f.write(chunk)
51 | progress += len(chunk)
52 | bar.update(progress - bar.n)
53 | bar.close()
54 |
55 | @property
56 | def file_id(self):
57 | return self.url.split('?')[0].split('/')[-2]
58 |
59 | @property
60 | def fpath(self):
61 | return os.path.join(self.out_dir, self.out_name)
62 |
63 | def download(self):
64 | os.makedirs(self.out_dir, exist_ok=True)
65 |
66 | if os.path.isfile(self.fpath):
67 | print('File already downloaded:', self.fpath)
68 | else:
69 | session = requests.Session()
70 | response = session.get(self.base_url, params={'id': self.file_id}, stream=True)
71 | token = self._get_confirm_token(response)
72 |
73 | if token:
74 | response = session.get(self.base_url, params={'id': self.file_id, 'confirm': token}, stream=True)
75 | else:
76 | raise RuntimeError()
77 |
78 | self._save_response_content(response)
79 |
80 | return self.fpath
81 |
82 |
83 | def main():
84 | os.makedirs("../data/", exist_ok=True)
85 | url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
86 | dl = GoogleDriveDownloader(url, '../data/')
87 | dl.download()
88 |
89 | url_b = 'https://drive.google.com/file/d/1MVfYF0qVgKHTQKFdA7lexGWnRIs7Ax9c/view?usp=sharing'
90 | dl_b = GoogleDriveDownloader(url_b, '../data/')
91 | dl_b.download()
92 |
93 | os.system("unzip ../data/birds.zip -d ../data/")
94 | os.system("tar -xvf ../data/CUB_200_2011.tgz -C ../data/")
95 |
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
--------------------------------------------------------------------------------
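For reference, the `file_id` property above extracts the Drive id with two string splits. A standalone illustration on the share link hard-coded in `main()` (no download is performed):

```python
# GoogleDriveDownloader.file_id: drop the query string, then take the
# second-to-last path segment of the share URL.
url = "https://drive.google.com/file/d/1MVfYF0qVgKHTQKFdA7lexGWnRIs7Ax9c/view?usp=sharing"
file_id = url.split("?")[0].split("/")[-2]
print(file_id)  # 1MVfYF0qVgKHTQKFdA7lexGWnRIs7Ax9c
```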
/README.md:
--------------------------------------------------------------------------------
1 | # StackGAN
2 |
3 | - PyTorch implementation of the paper [StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/pdf/1612.03242.pdf) by Han Zhang, Tao Xu, Hongsheng Li, Shaoting Zhang, Xiaogang Wang, Xiaolei Huang, Dimitris Metaxas.
4 |
5 | ## :bulb: What's new?
6 | - We use BERT embeddings for the text descriptions instead of the char-CNN-RNN text embeddings used in the original implementation.
7 |
8 |
9 | ## Pretrained model
10 | - [Stage 1](https://drive.google.com/drive/folders/14AyNcu7oZJe2aMevynAbYIpMKN7I3yHT?usp=sharing) trained using BERT embeddings instead of the original char-CNN-RNN text embeddings
11 | - [Stage 2](https://drive.google.com/drive/folders/1Pyndsp9oraE15ssD4MZJBVsyLW1ECCIi?usp=sharing) trained using BERT embeddings instead of the original char-CNN-RNN text embeddings
12 |
13 | ## Paper examples
14 | #### :bird: Examples for birds (char-CNN-RNN embeddings), more on [youtube](https://youtu.be/93yaf_kE0Fg):
15 | 
16 | 
17 | 
18 | 
19 |
20 | --------------------------------------------------------------------------------------------
21 |
22 | #### :sunflower: Examples for flowers (char-CNN-RNN embeddings), more on [youtube](https://youtu.be/SuRyL5vhCIM):
23 | 
24 | 
25 | 
26 | 
27 |
28 | --------------------------------------------------------------------------------------------
29 |
30 | ## :clipboard: Dependencies
31 | ```bash
32 | git clone https://github.com/sahilkhose/StackGAN-BERT.git
33 | pip3 install -r requirements.txt
34 | ```
35 |
36 | ## Dataset
37 | Check the instructions in `input/README.md`:
38 | ```bash
39 | cd input/src
40 | python3 data.py
41 | ```
42 |
43 | ## Generating BERT embeddings of annotations
44 | Change `DEVICE` to `cpu` in `input/src/config.py` if CUDA is not available.
45 | ```bash
46 | python3 bert_emb.py
47 | ```
48 |
49 | ## :wrench: Training
50 | ```bash
51 | cd ../../src
52 | ```
53 | Option 1: CLI args training (see `src/args.py`)
54 | ```bash
55 | python3 train.py --TRAIN_MAX_EPOCH 10
56 | ```
57 | Option 2: yaml config training (see `cfg/s1.yml` and `cfg/s2.yml`)
58 | ```bash
59 | python3 train.py --conf ../cfg/s1.yml
60 |
61 | mkdir ../old_outputs
62 | mv ../output ../old_outputs/output_stage-1
63 |
64 | python3 train.py --conf ../cfg/s2.yml
65 |
66 | mv ../output ../old_outputs/output_stage-2
67 | ```
68 | To launch TensorBoard:
69 | ```bash
70 | tensorboard --logdir=../output
71 | ```
72 |
73 | --------------------------------------------------------------------------------------------
74 |
75 | ## :books: Citing StackGAN
76 | If you find StackGAN useful in your research, please consider citing:
77 |
78 | ```
79 | @inproceedings{han2017stackgan,
80 | Author = {Han Zhang and Tao Xu and Hongsheng Li and Shaoting Zhang and Xiaogang Wang and Xiaolei Huang and Dimitris Metaxas},
81 | Title = {StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks},
82 | Year = {2017},
83 | booktitle = {{ICCV}},
84 | }
85 | ```
86 |
87 | **Follow-up work**
88 |
89 | - [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916)
90 | - [AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks](https://arxiv.org/abs/1711.10485) [[supplementary]](https://1drv.ms/b/s!Aj4exx_cRA4ghK5-kUG-EqH7hgknUA) [[code]](https://github.com/taoxugit/AttnGAN)
91 |
92 | **References**
93 |
94 | - Generative Adversarial Text-to-Image Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016)
95 | - Learning Deep Representations of Fine-grained Visual Descriptions [Paper](https://arxiv.org/abs/1605.05395) [Code](https://github.com/reedscot/cvpr2016)
96 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | """CUB Dataset class.
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | import args
8 |
9 | import cv2
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import os
13 | import pandas as pd
14 | import pickle
15 | import PIL
16 | import torch
17 |
18 | from torchvision import transforms
19 |
20 | from PIL import Image
21 | print("__"*80)
22 |
23 |
24 | class CUBDataset(torch.utils.data.Dataset):
25 | def __init__(self, pickl_file, img_dir, bert_emb=None, cnn_emb=None, stage=1):
26 | self.file_names = pd.read_pickle(pickl_file)[:-23]
27 | self.img_dir = img_dir
28 | self.bert_emb = bert_emb
29 | if cnn_emb is not None:
30 | self.cnn_emb = np.array(pickle.load(open(cnn_emb, "rb"), encoding='latin1')) # to switch to cnn-rnn embeddings
31 | else:
32 | self.cnn_emb = cnn_emb
33 | self.stage = stage
34 | self.f_to_bbox = dict_bbox()
35 |
36 | def __len__(self):
37 | # Total number of samples
38 | return len(self.file_names)
39 |
40 | def __getitem__(self, index):
41 | # Select sample:
42 | data_id = str(self.file_names[index])
43 |
44 | # Fetch text emb, image, bbox:
45 | idx = np.random.randint(0, 9)
46 | if self.cnn_emb is not None:
47 | text_emb = torch.tensor(self.cnn_emb[index][idx], dtype=torch.float) # (1024)
48 | else:
49 | text_emb = torch.load(os.path.join(self.bert_emb, data_id)+f"/{idx}.pt", map_location="cpu")
50 | text_emb = text_emb.squeeze(0) # (768)
51 |
52 | bbox = self.f_to_bbox[data_id]
53 | if self.stage == 1:
54 | image = get_img(img_path=os.path.join(self.img_dir, data_id) + ".jpg", bbox=bbox, image_size=(64, 64))
55 | else:
56 | image = get_img(img_path=os.path.join(self.img_dir, data_id) + ".jpg", bbox=bbox, image_size=(256, 256))
57 | # image = torch.tensor(np.array(image), dtype=torch.float)
58 | image = transforms.ToTensor()(image)
59 |
60 | return text_emb, image
61 |
62 |
63 | def dict_bbox():
64 | """
65 | Returns a filename-to-bbox dict.
66 | """
67 | data_args = args.get_all_args()
68 |
69 | df_bbox = pd.read_csv(data_args.bounding_boxes, delim_whitespace=True, header=None).astype(int)
70 | df_filenames = pd.read_csv(data_args.images_id_file, delim_whitespace=True, header=None)
71 |
72 | filenames = df_filenames[1].tolist()
73 |
74 | filename_bbox = {}
75 | for i in range(len(filenames)):
76 | bbox = df_bbox.iloc[i][1:].tolist()
77 | key = filenames[i].replace(".jpg", "")
78 | filename_bbox[key] = bbox
79 |
80 | return filename_bbox
81 |
82 | def get_img(img_path, bbox, image_size):
83 | """
84 | Load and resize image
85 | """
86 | img = Image.open(img_path).convert('RGB')
87 | width, height = img.size
88 | if bbox is not None:
89 | R = int(np.maximum(bbox[2], bbox[3]) * 0.75)
90 | center_x = int((2 * bbox[0] + bbox[2]) / 2)
91 | center_y = int((2 * bbox[1] + bbox[3]) / 2)
92 | y1 = np.maximum(0, center_y - R)
93 | y2 = np.minimum(height, center_y + R)
94 | x1 = np.maximum(0, center_x - R)
95 | x2 = np.minimum(width, center_x + R)
96 | img = img.crop([x1, y1, x2, y2])
97 | img = img.resize(image_size, PIL.Image.BILINEAR)
98 | return img
99 |
100 | if __name__ == "__main__":
101 | data_args = args.get_all_args()
102 | train_filenames = data_args.train_filenames
103 | test_filenames = data_args.test_filenames
104 |
105 | bert_embed = False
106 |
107 | if bert_embed:
108 | ###* Bert embeddings dataset:
109 | dataset_test = CUBDataset(train_filenames, data_args.images_dir, bert_emb=data_args.bert_annotations_dir)
110 | else:
111 | ###* cnn-rnn embeddings dataset:
112 | # cnn_embeddings = np.array(pickle.load(open(data_args.cnn_annotations_emb_train, "rb"), encoding='latin1'))
113 | dataset_test = CUBDataset(train_filenames, data_args.images_dir, cnn_emb=data_args.cnn_annotations_emb_train)
114 |
115 |
116 | t, i = dataset_test[1]
117 | if bert_embed:
118 | print("Bert emb shape: ", t.shape)
119 | else:
120 | print("rnn-cnn emb shape: ", t.shape)
121 | print("Image shape: ", i.shape)
122 | plt.imshow(i.permute(1, 2, 0))
123 | plt.show()
124 |
125 | ###########################################################
126 | # filename = "001.Black_footed_Albatross/Black_Footed_Albatross_0046_18"
127 | # f_to_bbox = dict_bbox()
128 | # print(f_to_bbox[filename])
129 |
--------------------------------------------------------------------------------
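The crop in `get_img()` keeps a square window of half-side `R = 0.75 * max(w, h)` around the bounding-box centre, clamped to the image borders. A worked example of that arithmetic (the bbox and image size below are made up for illustration):

```python
# Reproduce the crop-window arithmetic from get_img() on example numbers.
import numpy as np

bbox = [60, 27, 325, 304]    # x, y, width, height (illustrative)
width, height = 500, 375     # illustrative image size

R = int(np.maximum(bbox[2], bbox[3]) * 0.75)                             # 243
center_x = int((2 * bbox[0] + bbox[2]) / 2)                              # 222
center_y = int((2 * bbox[1] + bbox[3]) / 2)                              # 179
x1, x2 = np.maximum(0, center_x - R), np.minimum(width, center_x + R)    # 0, 465
y1, y2 = np.maximum(0, center_y - R), np.minimum(height, center_y + R)   # 0, 375
print((x1, y1, x2, y2))      # window passed to img.crop(...)
```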
/input/src/bert_emb.py:
--------------------------------------------------------------------------------
1 | '''
2 | Generates bert embeddings from annotations.
3 |
4 | Authors:
5 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
6 | Sahil Khose (sahilkhose18@gmail.com)
7 | '''
8 |
9 | import config
10 |
11 | import os
12 | import torch
13 | import logging
14 | logging.basicConfig(level=logging.ERROR)
15 |
16 | # from transformers import BertTokenizer, BertModel
17 | from transformers import AutoTokenizer, AutoModelForMaskedLM
18 | from tqdm import tqdm
19 |
20 | print("__"*80)
21 | print("Imports finished.")
22 | print("Loading BERT Tokenizer and model...")
23 | ###############################################################################################
24 | # tokenizer = BertTokenizer.from_pretrained(config.BERT_PATH, do_lower_case=True)
25 | # model = BertModel.from_pretrained(
26 | # config.BERT_PATH, output_hidden_states=True).to(config.DEVICE)
27 |
28 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
29 | model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased", output_hidden_states=True).to(config.DEVICE)
30 | model.eval()
31 |
32 | ###############################################################################################
33 | print("BERT tokenizer and model loaded.")
34 |
35 | def sent_emb(sent):
36 | encoded_dict = tokenizer.encode_plus(
37 | sent,
38 | add_special_tokens=True,
39 | max_length=128, # This is changed.
40 | pad_to_max_length=True,
41 | return_attention_mask=True,
42 | return_tensors='pt',
43 | truncation=True
44 | )
45 |
46 | input_ids = encoded_dict['input_ids']
47 | attention_masks = encoded_dict['attention_mask']
48 |
49 | with torch.no_grad():
50 | ### token embeddings:
51 | outputs = model(input_ids.to(config.DEVICE),
52 | attention_masks.to(config.DEVICE))
53 | hidden_states = outputs[2]
54 | token_embeddings = torch.stack(hidden_states, dim=0)
55 |
56 | ### sentence embeddings:
57 | token_vecs = hidden_states[-2][0]
58 | sentence_embedding = torch.mean(token_vecs, dim=0)
59 | sentence_embedding = sentence_embedding.view(1, -1)
60 | return sentence_embedding
61 |
62 |
63 | def max_len():
64 | max_len = 0
65 | emb_lens = []
66 | for bird_type in tqdm(sorted(os.listdir(config.ANNOTATIONS)), total=len(os.listdir(config.ANNOTATIONS))):
67 | for file in sorted(os.listdir(os.path.join(config.ANNOTATIONS, bird_type))):
68 | text = open(os.path.join(config.ANNOTATIONS, bird_type, file), "r").read().split('\n')[:-1]
69 | for annotation in text:
70 | input_ids = tokenizer.encode(annotation, add_special_tokens=True)
71 | r = len(input_ids)
72 | max_len = max(max_len, r)
73 | emb_lens.append(r)
74 | print(f"Saving emb lens to {config.EMB_LEN}")
75 | file = open(config.EMB_LEN, "w")
76 | for ele in emb_lens:
77 | file.write(str(ele) + "\n")
78 | return max_len
79 |
80 | def make_dir(dir_name):
81 | if not os.path.exists(dir_name):
82 | os.makedirs(dir_name)
83 |
84 | def generate_text_embs():
85 | print(f"Saving text embeddings to {config.ANNOTATION_EMB}")
86 | try:
87 | make_dir(config.ANNOTATION_EMB)
88 | for bird_type in tqdm(sorted(os.listdir(config.ANNOTATIONS)), total=len(os.listdir(config.ANNOTATIONS))):
89 | make_dir(os.path.join(config.ANNOTATION_EMB, bird_type))
90 | for file in sorted(os.listdir(os.path.join(config.ANNOTATIONS, bird_type))):
91 | make_dir(os.path.join(config.ANNOTATION_EMB, bird_type, file.replace(".txt", "")))
92 | text = open(os.path.join(config.ANNOTATIONS, bird_type, file), "r").read().split('\n')[:-1]
93 | for annotation_id, annotation in enumerate(text):
94 | emb = sent_emb(annotation)
95 | torch.save(emb, os.path.join(config.ANNOTATION_EMB, bird_type, file.replace(".txt", ""), str(annotation_id) + ".pt"))
96 | except Exception as e:
97 | print(f"Error in {bird_type}/{file}")
98 | print(e)
99 |
100 |
101 | if __name__ == "__main__":
102 | # To determine max_length in tokenizer.encode_plus()
103 | # print("Max length of annotations: ", max_len()) # 80
104 |
105 | # Generate sentence embeddings:
106 | if os.path.exists(config.ANNOTATION_EMB) and len(os.listdir(config.ANNOTATION_EMB)) == len(os.listdir(config.IMAGE_DIR)): # checking if we have all the embeddings folders created already
107 | print("Bert embeddings already exist. Skipping...")
108 | else:
109 | print("Generating bert embs...")
110 | generate_text_embs()
--------------------------------------------------------------------------------
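Each annotation is saved as a single `(1, 768)` sentence embedding at `ANNOTATION_EMB/<bird_type>/<filename>/<annotation_id>.pt`. A minimal sketch of loading one back (the folder below is illustrative; any generated embedding works):

```python
# Load one saved sentence embedding and check its shape.
import os
import torch

emb_path = os.path.join(
    "../data/birds/embeddings",        # config.ANNOTATION_EMB
    "001.Black_footed_Albatross",      # bird_type (illustrative)
    "Black_Footed_Albatross_0046_18",  # annotation file name without .txt
    "0.pt",                            # annotation_id
)
emb = torch.load(emb_path, map_location="cpu")
print(emb.shape)  # torch.Size([1, 768])
```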
/input/src/dataset_check.py:
--------------------------------------------------------------------------------
1 | """Displays the 10 annotations and the corresponding picture.
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | import config
8 |
9 | import numpy as np
10 | import os
11 | import matplotlib.pyplot as plt
12 | import pickle
13 | import torch
14 |
15 | from PIL import Image
16 | from scipy.spatial.distance import cosine
17 | from tqdm import tqdm
18 | print("__"*80)
19 | print("Imports finished")
20 | print("__"*80)
21 |
22 |
23 | def display_specific(bird_type_no=0, file_no=0, file_idx=None):
24 | """
25 | Prints annotations and displays images of a specific bird
26 | """
27 | bird_type = sorted(os.listdir(config.IMAGE_DIR))[bird_type_no]
28 | file = sorted(os.listdir(os.path.join(config.ANNOTATIONS, bird_type)))[file_no]
29 |
30 | if file_idx is None:
31 | filename = os.path.join(bird_type, file)
32 | else:
33 | filenames = np.array(pickle.load(open("../data/birds/train/filenames.pickle", "rb"), encoding='latin1'))
34 | filename = filenames[file_idx]
35 | filename += ".txt"
36 |
37 | print(f"\nFile: {filename}\n")
38 |
39 | text = open(os.path.join(config.ANNOTATIONS, filename), "r").read().split('\n')[:-1]
40 | [print(f"{idx}: {line}") for idx, line in enumerate(text)]
41 | filename = filename.replace(".txt", ".jpg")
42 | plt.imshow(plt.imread(os.path.join(config.IMAGE_DIR, filename)))
43 | plt.show()
44 |
45 | def compare_bert_emb(file_1, file_2, emb_no=0):
46 | emb_1 = torch.load(os.path.join(config.ANNOTATION_EMB, file_1, f"{emb_no}.pt"), map_location="cpu")
47 | emb_2 = torch.load(os.path.join(config.ANNOTATION_EMB, file_2, f"{emb_no}.pt"), map_location="cpu")
48 | # print(emb_1.shape) # (1, 768)
49 |
50 | bert_sim = 1 - cosine(emb_1, emb_2)
51 | print(f"cosine similarity bert emb: {bert_sim:.2f}")
52 |
53 | def compare_cnn_emb(emb_idx_1, emb_idx_2, emb_no=0):
54 | embeddings = np.array(pickle.load(open("../data/birds/train/char-CNN-RNN-embeddings.pickle", "rb"), encoding='latin1'))
55 | # print(embeddings.shape) # (8855, 10, 1024)
56 |
57 | cnn_sim = 1 - cosine(embeddings[emb_idx_1][emb_no], embeddings[emb_idx_2][emb_no])
58 | print(f"cosine similarity cnn embs: {cnn_sim:.2f}")
59 |
60 | def compare_embedding_quality(emb_idx_1=0, emb_idx_2=1, emb_no=0):
61 | ###* Filenames to fetch embs:
62 | filenames = np.array(pickle.load(open("../data/birds/train/filenames.pickle", "rb"), encoding='latin1'))
63 | # print(filenames.shape) # (8855, )
64 |
65 | ###* File paths:
66 | file_1 = filenames[emb_idx_1]
67 | file_2 = filenames[emb_idx_2]
68 |
69 | print(f"File 1: {file_1}")
70 | print(f"File 2: {file_2}\n")
71 |
72 | ###* Annotations:
73 | text1 = open(os.path.join(config.ANNOTATIONS, file_1+".txt"), "r").read().split('\n')[:-1]
74 | text2 = open(os.path.join(config.ANNOTATIONS, file_2+".txt"), "r").read().split('\n')[:-1]
75 | print("Annotation 1: ", text1[emb_no])
76 | print("Annotation 2: ", text2[emb_no])
77 | print()
78 |
79 | ###* Cosine similarity:
80 | compare_cnn_emb(emb_idx_1, emb_idx_2, emb_no=0)
81 | compare_bert_emb(file_1, file_2, emb_no=0)
82 |
83 | ###* Display images:
84 | fig = plt.figure()
85 | fig.add_subplot(1, 2, 1)
86 | plt.imshow(plt.imread(os.path.join(config.IMAGE_DIR, file_1 + ".jpg")))
87 | fig.add_subplot(1, 2, 2)
88 | plt.imshow(plt.imread(os.path.join(config.IMAGE_DIR, file_2 + ".jpg")))
89 | # plt.show()
90 |
91 | def check_model(file_idx, model):
92 | import sys
93 | sys.path.insert(1, "../../src/")
94 | import layers
95 | emb_no = 0
96 | ###* load the models
97 | netG = layers.Stage1Generator().cuda()
98 | netG.load_state_dict(torch.load(model))
99 | netG.eval()
100 | with torch.no_grad():
101 | ###* load the embeddings
102 | filenames = np.array(pickle.load(open("../data/birds/train/filenames.pickle", "rb"), encoding='latin1'))
103 | file_name = filenames[file_idx]
104 | emb = torch.load(os.path.join(config.ANNOTATION_EMB, file_name, f"{emb_no}.pt"))
105 |
106 | ###* Forward pass
107 | print(emb.shape) # (1, 768)
108 | noise = torch.autograd.Variable(torch.FloatTensor(1, 100)).cuda()
109 | noise.data.normal_(0, 1)
110 | _, fake_image, mu, logvar = netG(emb, noise)
111 | fake_image = fake_image.squeeze(0)
112 | print(fake_image.shape) #(3, 64, 64)
113 |
114 | im_save(fake_image, count=0)
115 | return fake_image
116 |
117 | def im_save(fake_img, count=0):
118 | save_name = f"{count}.png"
119 | im = fake_img.cpu().numpy()
120 | im = (im + 1.0) * 127.5
121 | im = im.astype(np.uint8)
122 | # print("im", im.shape)
123 | im = np.transpose(im, (1, 2, 0))
124 | # print("im", im.shape)
125 | im = Image.fromarray(im)
126 | im.save(save_name)
127 |
128 | if __name__ == "__main__":
129 | # display_specific(bird_type_no=0, file_no=0) # old method
130 | display_specific(file_idx=14) # new method
131 |
132 | print("__"*80)
133 | # compare_embedding_quality(emb_idx_1=0, emb_idx_2=1, emb_no=0)
134 | ###* emb_idx < 8855, emb_no < 10
135 |
136 |
137 | print("__"*80)
138 | check_model(file_idx=14,
139 | model="../../old_outputs/output-3/model/netG_epoch_110.pth")
140 |
141 | plt.show()
--------------------------------------------------------------------------------
/src/engine.py:
--------------------------------------------------------------------------------
1 | """Train and Eval functions
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | import util
8 |
9 | import torch
10 | import torch.nn as nn
11 |
12 | from torch.utils.tensorboard import summary
13 | from torch.utils.tensorboard import FileWriter
14 | from tqdm import tqdm
15 |
16 | def KL_loss(mu, logvar):
17 | # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
18 | KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
19 | KLD = torch.mean(KLD_element).mul_(-0.5)
20 | return KLD
21 |
22 | def disc_loss(disc, real_imgs, fake_imgs, real_labels, fake_labels, conditional_vector):
23 | loss_fn = nn.BCELoss()
24 | batch_size = real_imgs.shape[0]
25 | cond = conditional_vector.detach()
26 | fake_imgs = fake_imgs.detach()
27 |
28 | real_loss = loss_fn(disc(cond, real_imgs), real_labels)
29 | fake_loss = loss_fn(disc(cond, fake_imgs), fake_labels)
30 | wrong_loss = loss_fn(disc(cond[1:], real_imgs[:-1]), fake_labels[1:])
31 | loss = real_loss + (fake_loss+wrong_loss)*0.5
32 | return loss, real_loss, wrong_loss, fake_loss
33 |
34 | def gen_loss(disc, fake_imgs, real_labels, conditional_vector):
35 | loss_fn = nn.BCELoss()
36 | cond = conditional_vector.detach()
37 | fake_loss = loss_fn(disc(cond, fake_imgs), real_labels)
38 | return fake_loss
39 |
40 | def weights_init(m):
41 | classname = m.__class__.__name__
42 | if classname.find('Conv') != -1:
43 | m.weight.data.normal_(0.0, 0.02)
44 | elif classname.find('BatchNorm') != -1:
45 | m.weight.data.normal_(1.0, 0.02)
46 | m.bias.data.fill_(0)
47 | elif classname.find('Linear') != -1:
48 | m.weight.data.normal_(0.0, 0.02)
49 | if m.bias is not None:
50 | m.bias.data.fill_(0.0)
51 |
52 |
53 | def train_new_fn(data_loader, args, netG, netD, real_labels, fake_labels, noise, fixed_noise, optimizerD, optimizerG, epoch, count, summary_writer):
54 | errD_, errD_real_, errD_wrong_, errD_fake_, errG_, kl_loss_ = 0, 0, 0, 0, 0, 0
55 | for batch_id, data in tqdm(enumerate(data_loader), total=len(data_loader), desc=f"Train Epoch {epoch}/{args.TRAIN_MAX_EPOCH}"):
56 | ###* Prepare training data:
57 | text_emb, real_images = data
58 | text_emb = text_emb.to(args.device)
59 | real_images = real_images.to(args.device)
60 |
61 | ###* Generate fake images:
62 | noise.data.normal_(0, 1)
63 | _, fake_images, mu, logvar = netG(text_emb, noise)
64 |
65 | ###* Update D network:
66 | netD.zero_grad()
67 | errD, errD_real, errD_wrong, errD_fake = disc_loss(netD, real_images, fake_images, real_labels, fake_labels, text_emb)
68 | errD.backward()
69 | optimizerD.step()
70 |
71 | ###* Update G network:
72 | netG.zero_grad()
73 | errG = gen_loss(netD, fake_images, real_labels, text_emb)
74 | kl_loss = KL_loss(mu, logvar)
75 | errG_total = errG + kl_loss * args.TRAIN_COEFF_KL
76 | errG_total.backward()
77 | optimizerG.step()
78 |
79 | count += 1
80 |
81 | if batch_id % 10 == 0:
82 | summary_D = summary.scalar("D_loss", errD.data)
83 | summary_D_r = summary.scalar("D_loss_real", errD_real.data)
84 | summary_D_w = summary.scalar("D_loss_wrong", errD_wrong.data)
85 | summary_D_f = summary.scalar("D_loss_fake", errD_fake.data)
86 | summary_G = summary.scalar("G_loss", errG.data)
87 | summary_KL = summary.scalar("KL_loss", kl_loss.data)
88 |
89 | summary_writer.add_summary(summary_D, count)
90 | summary_writer.add_summary(summary_D_r, count)
91 | summary_writer.add_summary(summary_D_w, count)
92 | summary_writer.add_summary(summary_D_f, count)
93 | summary_writer.add_summary(summary_G, count)
94 | summary_writer.add_summary(summary_KL, count)
95 |
96 | ###* save the image result for each epoch:
97 | lr_fake, fake, _, _ = netG(text_emb, fixed_noise)
98 | util.save_img_results(real_images, fake, epoch, args)
99 | if lr_fake is not None:
100 | util.save_img_results(None, lr_fake, epoch, args)
101 |
102 | errD_ += errD
103 | errD_real_ += errD_real
104 | errD_wrong_ += errD_wrong
105 | errD_fake_ += errD_fake
106 | errG_ += errG
107 | kl_loss_ += kl_loss
108 |
109 | errD_ /= len(data_loader)
110 | errD_real_ /= len(data_loader)
111 | errD_wrong_ /= len(data_loader)
112 | errD_fake_ /= len(data_loader)
113 | errG_ /= len(data_loader)
114 | kl_loss_ /= len(data_loader)
115 |
116 | return errD_, errD_real_, errD_wrong_, errD_fake_, errG_, kl_loss_, count
117 |
118 |
119 |
120 |
121 | def eval_fn(data_loader, model, device, epoch):
122 | model.eval()
123 | fin_y = []
124 | fin_outputs = []
125 | LOSS = 0.
126 |
127 | with torch.no_grad():
128 | for batch_id, data in tqdm(enumerate(data_loader), total=len(data_loader)):
129 | text_embs, images = data
130 |
131 | # Loading it to device
132 | text_embs = text_embs.to(device, dtype=torch.float)
133 | images = images.to(device, dtype=torch.float)
134 |
135 | # getting outputs from model and calculating loss
136 | outputs = model(text_embs, images)
137 | loss = loss_fn(outputs, images) # TODO figure this out
138 | LOSS += loss
139 |
140 | # for calculating accuracy and other metrics # TODO figure this out
141 | fin_y.extend(images.view(-1, 1).cpu().detach().numpy().tolist())
142 | fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())
143 |
144 | LOSS /= len(data_loader)
145 | return fin_outputs, fin_y, LOSS
146 |
--------------------------------------------------------------------------------
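`KL_loss` above is the VAE-style KL term against a standard normal, `-0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)`, averaged over all elements. A small standalone check that the chained in-place ops match the written-out formula (random inputs, no project code needed):

```python
# Sanity-check the KL_loss arithmetic against the plain formula.
import torch

def kl_loss(mu, logvar):
    # same chained in-place arithmetic as engine.KL_loss
    kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    return torch.mean(kld_element).mul_(-0.5)

mu = torch.randn(4, 128)
logvar = torch.randn(4, 128)
direct = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
print(torch.allclose(kl_loss(mu, logvar), direct))  # True
```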
/src/train.py:
--------------------------------------------------------------------------------
1 | """Test a model and generate submission CSV.
2 |
3 | > python3 train.py --conf ../cfg/s1.yml
4 |
5 | Usage:
6 | > python train.py --load_path PATH --name NAME
7 | where
8 | > PATH is a path to a checkpoint (e.g., save/train/model-01/best.pth.tar)
9 | > NAME is a name to identify the train run
10 |
11 | Authors:
12 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
13 | Sahil Khose (sahilkhose18@gmail.com)
14 | """
15 | import args
16 | # import config
17 | import dataset
18 | import engine
19 | import layers
20 | import util
21 |
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | import os
25 | import pandas as pd
26 | import time
27 | import torch
28 | import torchfile
29 |
30 |
31 | from PIL import Image
32 | from sklearn import metrics
33 | from sklearn import model_selection
34 | from torch.utils.tensorboard import summary
35 | from torch.utils.tensorboard import FileWriter
36 | from torch.autograd import Variable
37 |
38 | print("__"*80)
39 | print("Imports Done...")
40 |
41 |
42 | def load_stage1(args):
43 | #* Init models and weights:
44 | from layers import Stage1Generator, Stage1Discriminator
45 | if args.embedding_type == "bert":
46 | netG = Stage1Generator(emb_dim=768)
47 | netD = Stage1Discriminator(emb_dim=768)
48 | else:
49 | netG = Stage1Generator(emb_dim=1024)
50 | netD = Stage1Discriminator(emb_dim=1024)
51 |
52 | netG.apply(engine.weights_init)
53 | netD.apply(engine.weights_init)
54 |
55 | #* Load saved model:
56 | if args.NET_G_path != "":
57 | netG.load_state_dict(torch.load(args.NET_G_path))
58 | print("__"*80)
59 | print("Generator loaded from: ", args.NET_G_path)
60 | print("__"*80)
61 | if args.NET_D_path != "":
62 | netD.load_state_dict(torch.load(args.NET_D_path))
63 | print("__"*80)
64 | print("Discriminator loaded from: ", args.NET_D_path)
65 | print("__"*80)
66 |
67 | #* Load on device:
68 | if args.device == "cuda":
69 | netG.cuda()
70 | netD.cuda()
71 |
72 | print("__"*80)
73 | print("GENERATOR:")
74 | print(netG)
75 | print("__"*80)
76 | print("DISCRIMINATOR:")
77 | print(netD)
78 | print("__"*80)
79 |
80 | return netG, netD
81 |
82 |
83 | def load_stage2(args):
84 | #* Init models and weights:
85 | from layers import Stage2Generator, Stage2Discriminator, Stage1Generator
86 | if args.embedding_type == "bert":
87 | Stage1_G = Stage1Generator(emb_dim=768)
88 | netG = Stage2Generator(Stage1_G, emb_dim=768)
89 | netD = Stage2Discriminator(emb_dim=768)
90 | else:
91 | Stage1_G = Stage1Generator(emb_dim=1024)
92 | netG = Stage2Generator(Stage1_G, emb_dim=1024)
93 | netD = Stage2Discriminator(emb_dim=1024)
94 | netG.apply(engine.weights_init)
95 | netD.apply(engine.weights_init)
96 |
97 | #* Load saved model:
98 | if args.NET_G_path != "":
99 | netG.load_state_dict(torch.load(args.NET_G_path))
100 | print("Generator loaded from: ", args.NET_G_path)
101 | elif args.STAGE1_G_path != "":
102 | netG.stage1_gen.load_state_dict(torch.load(args.STAGE1_G_path))
103 | print("Generator 1 loaded from: ", args.STAGE1_G_path)
104 | else:
105 | print("Please give the Stage 1 generator path")
106 | return
107 |
108 | if args.NET_D_path != "":
109 | netD.load_state_dict(torch.load(args.NET_D_path))
110 | print("Discriminator loaded from: ", args.NET_D_path)
111 |
112 | #* Load on device:
113 | if args.device == "cuda":
114 | netG.cuda()
115 | netD.cuda()
116 |
117 | print("__"*80)
118 | print(netG)
119 | print("__"*80)
120 | print(netD)
121 | print("__"*80)
122 |
123 | return netG, netD
124 |
125 |
126 | def run(args):
127 | if args.STAGE == 1:
128 | netG, netD = load_stage1(args)
129 | else:
130 | netG, netD = load_stage2(args)
131 |
132 | # Setting up device
133 | device = torch.device(args.device)
134 |
135 | # Load model
136 | netG.to(device)
137 | netD.to(device)
138 |
139 | nz = args.n_z
140 | batch_size = args.train_bs
141 | noise = Variable(torch.FloatTensor(batch_size, nz)).to(device)
142 | with torch.no_grad():
143 | fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1)).to(device) # volatile=True
144 | real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)).to(device)
145 | fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)).to(device)
146 |
147 | gen_lr = args.TRAIN_GEN_LR
148 | disc_lr = args.TRAIN_DISC_LR
149 |
150 | lr_decay_step = args.TRAIN_LR_DECAY_EPOCH
151 |
152 | optimizerD = torch.optim.Adam(netD.parameters(), lr=args.TRAIN_DISC_LR, betas=(0.5, 0.999))
153 |
154 | netG_para = []
155 | for p in netG.parameters():
156 | if p.requires_grad:
157 | netG_para.append(p)
158 | optimizerG = torch.optim.Adam(netG_para, lr=args.TRAIN_GEN_LR, betas=(0.5, 0.999))
159 |
160 | count = 0
161 |
162 | if args.embedding_type == "bert":
163 | training_set = dataset.CUBDataset(pickl_file=args.train_filenames, img_dir=args.images_dir, bert_emb=args.bert_annotations_dir, stage=args.STAGE)
164 | testing_set = dataset.CUBDataset(pickl_file=args.test_filenames, img_dir=args.images_dir, bert_emb=args.bert_annotations_dir, stage=args.STAGE)
165 | else:
166 | training_set = dataset.CUBDataset(pickl_file=args.train_filenames, img_dir=args.images_dir, cnn_emb=args.cnn_annotations_emb_train, stage=args.STAGE)
167 | testing_set = dataset.CUBDataset(pickl_file=args.test_filenames, img_dir=args.images_dir, cnn_emb=args.cnn_annotations_emb_test, stage=args.STAGE)
168 | train_data_loader = torch.utils.data.DataLoader(training_set, batch_size=args.train_bs, num_workers=args.train_workers)
169 | test_data_loader = torch.utils.data.DataLoader(testing_set, batch_size=args.test_bs, num_workers=args.test_workers)
170 | # util.check_dataset(training_set)
171 | # util.check_dataset(testing_set)
172 |
173 |
174 | # best_accuracy = 0
175 |
176 | util.make_dir(args.image_save_dir)
177 | util.make_dir(args.model_dir)
178 | util.make_dir(args.log_dir)
179 | summary_writer = FileWriter(args.log_dir)
180 |
181 | for epoch in range(1, args.TRAIN_MAX_EPOCH+1):
182 | print("__"*80)
183 | start_t = time.time()
184 |
185 | if epoch % lr_decay_step == 0 and epoch > 0:
186 | gen_lr *= 0.5
187 | for param_group in optimizerG.param_groups:
188 | param_group["lr"] = gen_lr
189 | disc_lr *= 0.5
190 | for param_group in optimizerD.param_groups:
191 | param_group["lr"] = disc_lr
192 |
193 | errD, errD_real, errD_wrong, errD_fake, errG, kl_loss, count = engine.train_new_fn(
194 | train_data_loader, args, netG, netD, real_labels, fake_labels,
195 | noise, fixed_noise, optimizerD, optimizerG, epoch, count, summary_writer)
196 |
197 | end_t = time.time()
198 |
199 | print(f"[{epoch}/{args.TRAIN_MAX_EPOCH}] Loss_D: {errD:.4f}, Loss_G: {errG:.4f}, Loss_KL: {kl_loss:.4f}, Loss_real: {errD_real:.4f}, Loss_wrong: {errD_wrong:.4f}, Loss_fake: {errD_fake:.4f}, Total Time: {end_t-start_t :.2f} sec")
200 | if epoch % args.TRAIN_SNAPSHOT_INTERVAL == 0 or epoch == 1:
201 | util.save_model(netG, netD, epoch, args)
202 |
203 | util.save_model(netG, netD, args.TRAIN_MAX_EPOCH, args)
204 | summary_writer.close()
205 |
206 |
207 | def sample(args, datapath):
208 | if args.STAGE == 1:
209 | netG, _ = load_stage1(args)
210 | else:
211 | netG, _ = load_stage2(args)
212 | netG.eval()
213 |
214 | ###* Load text embeddings generated from the encoder:
215 | t_file = torchfile.load(datapath)
216 | captions_list = t_file.raw_txt
217 | embeddings = np.concatenate(t_file.fea_txt, axis=0)
218 | num_embeddings = len(captions_list)
219 | print(f"Successfully load sentences from: {args.datapath}")
220 | print(f"Total number of sentences: {num_embeddings}")
221 | print(f"Num embeddings: {num_embeddings} {embeddings.shape}")
222 |
223 | ###* Path to save generated samples:
224 | save_dir = args.NET_G[:args.NET_G.find(".pth")]
225 | util.make_dir(save_dir)
226 |
227 | batch_size = np.minimum(num_embeddings, args.train_bs)
228 | nz = args.n_z
229 | noise = Variable(torch.FloatTensor(batch_size, nz))
230 | noise = noise.to(args.device)
231 | count = 0
232 | while count < num_embeddings:
233 | if count > 3000:
234 | break
235 | iend = count + batch_size
236 | if iend > num_embeddings:
237 | iend = num_embeddings
238 | count = num_embeddings - batch_size
239 | embeddings_batch = embeddings[count:iend]
240 | # captions_batch = captions_list[count:iend]
241 | text_embedding = Variable(torch.FloatTensor(embeddings_batch))
242 | text_embedding = text_embedding.to(args.device)
243 |
244 | ###* Generate fake images:
245 | noise.data.normal_(0, 1)
246 | _, fake_imgs, mu, logvar = netG(text_embedding, noise)
247 | for i in range(batch_size):
248 | save_name = f"{save_dir}/{count+i}.png"
249 | im = fake_imgs[i].data.cpu().numpy()
250 | im = (im + 1.0) * 127.5
251 | im = im.astype(np.uint8)
252 | # print("im", im.shape)
253 | im = np.transpose(im, (1, 2, 0))
254 | # print("im", im.shape)
255 | im = Image.fromarray(im)
256 | im.save(save_name)
257 | count += batch_size
258 |
259 | if __name__ == "__main__":
260 | args_ = args.get_all_args()
261 | args.print_args(args_)
262 | run(args_)
263 | # datapath = os.path.join(args_.datapath, "test/val_captions.t7")
264 | # sample(args_, datapath)
265 |
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | """Command-line arguments.
2 |
3 | Authors:
4 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
5 | Sahil Khose (sahilkhose18@gmail.com)
6 | """
7 | # > python3 args.py --conf ../cfg/s1.yml
8 | import argparse
9 | import yaml
10 | from json import dumps
11 |
12 | ########################################### MAIN FUNCTIONS ###########################################
13 |
14 | def get_all_args():
15 | return parse_args(get_parser())
16 |
17 |
18 | def print_args(args):
19 | print("__"*80)
20 | print(f'ARGUMENTS: \n{dumps(vars(args), indent=4, sort_keys=True)}')
21 | print("__"*80)
22 |
23 |
24 | def parse_args(parser):
25 | args = parser.parse_args()
26 | if args.config_file:
27 | data = yaml.load(args.config_file, Loader=yaml.Loader)
28 | delattr(args, "config_file")
29 | arg_dict = args.__dict__
30 | for key, value in data.items():
31 | if isinstance(value, list):
32 | for v in value:
33 | arg_dict[key].append(v)
34 | else:
35 | arg_dict[key] = value
36 | return args
37 |
38 |
39 | def get_parser():
40 | """Get all parameters"""
41 | parser = argparse.ArgumentParser("Arguments")
42 |
43 | add_conf_args(parser)
44 | add_train_args(parser)
45 | add_model_args(parser)
46 | add_data_args(parser)
47 |
48 | return parser
49 |
50 |
51 | ########################################### PARSER ARGUMENTS ###########################################
52 | def add_conf_args(parser):
53 | parser.add_argument("--config-file",
54 | type=argparse.FileType(mode="r"),
55 | dest="config_file",
56 | help="config yaml file to pass params")
57 |
58 |
59 | def add_train_args(parser):
60 | """ train.py """
61 | parser.add_argument("--NET_G_path",
62 | type=str,
63 | default="",
64 | help="Generator model loading path")
65 | parser.add_argument("--NET_D_path",
66 | type=str,
67 | default="",
68 | help="Discriminator model loading path")
69 | parser.add_argument("--STAGE1_G_path",
70 | type=str,
71 | default="../old_outputs/output_0/model/netG_epoch_120.pth",
72 | help="Stage 1 Generator model path for Stage 2 training")
73 | parser.add_argument("--train_bs",
74 | type=int,
75 | default=2,
76 | help="train batch size")
77 | parser.add_argument("--test_bs",
78 | type=int,
79 | default=1,
80 | help="test batch size")
81 | parser.add_argument("--train_workers",
82 | type=int,
83 | default=1,
84 | help="train num_workers")
85 | parser.add_argument("--test_workers",
86 | type=int,
87 | default=1,
88 | help="test num_workers")
89 | parser.add_argument("--TRAIN_GEN_LR",
90 | type=float,
91 | default=2e-4,
92 | help="train generator learning rate")
93 | parser.add_argument("--TRAIN_DISC_LR",
94 | type=float,
95 | default=2e-4,
96 | help="test discriminator learning rate")
97 | parser.add_argument("--TRAIN_LR_DECAY_EPOCH",
98 | type=int,
99 | default=20,
100 | help="train lr decay epoch")
101 | parser.add_argument("--TRAIN_MAX_EPOCH",
102 | type=int,
103 | default=120,
104 | help="train maximum epochs")
105 | parser.add_argument("--TRAIN_SNAPSHOT_INTERVAL",
106 | type=int,
107 | default=5,
108 | help="Snapshot interval")
109 | parser.add_argument("--TRAIN_COEFF_KL",
110 | type=float,
111 | default=2.0,
112 | help="train coefficient KL")
113 | parser.add_argument("--dataset_name",
114 | type=str,
115 | default="birds",
116 | help="birds/flowers: dataset name")
117 | parser.add_argument("--embedding_type",
118 | type=str,
119 | default="bert",
120 | help="bert/cnn-rnn: embedding type")
121 | parser.add_argument("--datapath",
122 | type=str,
123 | default="/gdrive/MyDrive/ganctober_training/output",
124 | help="datapath dir")
125 | parser.add_argument("--image_save_dir",
126 | type=str,
127 | default="/gdrive/MyDrive/ganctober_training/output/image/",
128 | help="Image save dir")
129 | parser.add_argument("--model_dir",
130 | type=str,
131 | default="/gdrive/MyDrive/ganctober_training/output/model/",
132 | help="Model save dir")
133 | parser.add_argument("--log_dir",
134 | type=str,
135 | default="/gdrive/MyDrive/ganctober_training/output/log/",
136 | help="Log dir for tensorboard")
137 | parser.add_argument("--VIS_COUNT",
138 | type=int,
139 | default=64,
140 | help="")
141 | parser.add_argument("--STAGE",
142 | type=int,
143 | default=2,
144 | help="Stage to train/eval (1/2)")
145 | parser.add_argument("--device",
146 | type=str,
147 | default="cuda", #! CHANGE THIS TO CUDA BEFORE TRAINING
148 | help="Device type: cuda/cpu")
149 |
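As a usage illustration (assuming train.py builds its namespace via get_all_args(), as the __main__ block at the bottom of this file does), any of the training flags above can be overridden on the command line, e.g.:

    # > python3 train.py --STAGE 1 --device cpu --train_bs 64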
150 |
151 | def add_model_args(parser):
152 | """
153 | Refer to StackGAN paper: https://arxiv.org/pdf/1612.03242.pdf
154 | for parameter names.
155 | """
156 | parser.add_argument("--n_g",
157 | type=int,
158 | default=128,
159 | help="")
160 | parser.add_argument("--n_z",
161 | type=int,
162 | default=100,
163 | help="")
164 | parser.add_argument("--m_d",
165 | type=int,
166 | default=4,
167 | help="")
168 | parser.add_argument("--m_g",
169 | type=int,
170 | default=16,
171 | help="")
172 | parser.add_argument("--n_d",
173 | type=int,
174 | default=128,
175 | help="")
176 | parser.add_argument("--w_0",
177 | type=int,
178 | default=64,
179 | help="")
180 | parser.add_argument("--h_0",
181 | type=int,
182 | default=256,
183 | help="")
184 | parser.add_argument("--w",
185 | type=int,
186 | default=256,
187 | help="")
188 | parser.add_argument("--h",
189 | type=int,
190 | default=256,
191 | help="")
192 |
193 |
194 | def add_data_args(parser):
195 | '''Get all data paths'''
196 | ###* Directories:
197 | parser.add_argument("--annotations_dir",
198 | type=str,
199 | default="../input/data/birds/text_c10/",
200 | help="Annotations dir path")
201 | parser.add_argument("--bert_annotations_dir",
202 | type=str,
203 | default="../input/data/birds/embeddings/",
204 | help="Annotations BERT embeddings dir path")
205 | parser.add_argument("--bert_path",
206 | type=str,
207 | default="../input/data/bert_base_uncased/",
208 | help="Bert model dir path")
209 | parser.add_argument("--images_dir",
210 | type=str,
211 | default="../input/data/CUB_200_2011/images/",
212 | help="Images dir path")
213 |
214 | ###* Files:
215 | add_birds_file_args(parser)
216 | add_cub_file_args(parser)
217 |
218 |
219 | def add_birds_file_args(parser):
220 | """
221 | Paths for files under input/data/birds
222 | files:
223 | - filenames.pickle
224 | List[str] {filename DOES NOT contain a .jpg (or any) extension}
225 | To fetch the image: input/data/CUB_200_2011/images/<filename>.jpg == <images_dir>/<filename>.jpg
226 | To fetch the text annotation: input/data/birds/text_c10/<filename>.txt == <annotations_dir>/<filename>.txt
227 | To fetch the bert annotation: input/data/birds/embeddings/<filename>/[0-9].pt == <bert_annotations_dir>/<filename>/[0-9].pt
228 | There are 2 such files:
229 | input/data/birds/train/filenames.pickle : len = 8855
230 | input/data/birds/test/filenames.pickle : len = 2933
231 | """
232 | parser.add_argument("--cnn_annotations_emb_train",
233 | type=str,
234 | default="../input/data/birds/train/char-CNN-RNN-embeddings.pickle",
235 | help="char-CNN-RNN-embeddings pickle file for train")
236 | parser.add_argument("--cnn_annotations_emb_test",
237 | type=str,
238 | default="../input/data/birds/test/char-CNN-RNN-embeddings.pickle",
239 | help="char-CNN-RNN-embeddings pickle file for test")
240 | parser.add_argument("--train_filenames",
241 | type=str,
242 | default="../input/data/birds/train/filenames.pickle",
243 | help="Pickle file path: filenames for train set")
244 | parser.add_argument("--test_filenames",
245 | type=str,
246 | default="../input/data/birds/test/filenames.pickle",
247 | help="Pickle file path: filenames for test set")
248 |
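A small sketch of how one entry of filenames.pickle maps onto the three assets described in the docstring above (the example filename is made up; the pickle simply holds a list of extension-less strings):

    import pickle

    with open("../input/data/birds/train/filenames.pickle", "rb") as f:
        filenames = pickle.load(f)          # list of 8855 strings (add encoding="latin1" if the pickle was written by Python 2)

    name = filenames[0]                     # e.g. "002.Laysan_Albatross/Laysan_Albatross_0002_1027"
    image_path = "../input/data/CUB_200_2011/images/" + name + ".jpg"
    text_path = "../input/data/birds/text_c10/" + name + ".txt"
    bert_path = "../input/data/birds/embeddings/" + name + "/0.pt"   # one tensor per caption, indexed [0-9]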
249 |
250 | def add_cub_file_args(parser):
251 | """
252 | Paths for files under input/data/CUB_200_2011
253 | files:
254 | - images.txt
255 | <image_id> <image_name> {image_name contains the .jpg extension}
256 | To fetch the image: input/data/CUB_200_2011/images/<image_name>
257 | number_of_images = 11788
258 |
259 | - train_test_split.txt
260 | <image_id> <is_training_image>
261 | 1: train, 0: test
262 | train size = 5994
263 | test size = 5794
264 |
265 | - bounding_boxes.txt
266 | <image_id> <x> <y> <width> <height>
267 | """
268 | parser.add_argument("--images_id_file",
269 | type=str,
270 | default="../input/data/CUB_200_2011/images.txt",
271 | help="Text file path: mapping image id to image path") #
272 | parser.add_argument("--train_test_split_file",
273 | type=str,
274 | default="../input/data/CUB_200_2011/train_test_split.txt",
275 | help="Text file path: mapping image id to train/test split") #
276 | parser.add_argument("--bounding_boxes",
277 | type=str,
278 | default="../input/data/CUB_200_2011/bounding_boxes.txt",
279 | help="Text file path: mapping image id to train/test split") #
280 |
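For reference, the three CUB_200_2011 index files are whitespace-separated with one image per row; a throwaway parse (the variable names are mine, not from the repo) looks like:

    id_to_path, id_to_train, id_to_bbox = {}, {}, {}

    with open("../input/data/CUB_200_2011/images.txt") as f:
        for line in f:                                   # "<image_id> <image_name>"
            img_id, path = line.split()
            id_to_path[int(img_id)] = path

    with open("../input/data/CUB_200_2011/train_test_split.txt") as f:
        for line in f:                                   # "<image_id> <is_training_image>"
            img_id, flag = line.split()
            id_to_train[int(img_id)] = bool(int(flag))   # 1: train, 0: test

    with open("../input/data/CUB_200_2011/bounding_boxes.txt") as f:
        for line in f:                                   # "<image_id> <x> <y> <width> <height>"
            img_id, x, y, w, h = line.split()
            id_to_bbox[int(img_id)] = tuple(map(float, (x, y, w, h)))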
281 |
282 | if __name__ == "__main__":
283 | args = get_all_args()
284 | print_args(args)
285 | print("args:", args.train_bs)
286 |
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | """Assortment of layers for use in models.py.
2 | Refer to StackGAN paper: https://arxiv.org/pdf/1612.03242.pdf
3 | for variable names and working.
4 |
5 | Authors:
6 | Abhiraj Tiwari (abhirajtiwari@gmail.com)
7 | Sahil Khose (sahilkhose18@gmail.com)
8 | """
9 | import torch
10 | import torch.nn as nn
11 |
12 | from torch.autograd import Variable
13 |
14 |
15 | def conv3x3(in_channels, out_channels):
16 | """3x3 conv with same padding"""
17 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
18 |
19 |
20 | class ResBlock(nn.Module):
21 | def __init__(self, channel_num):
22 | super(ResBlock, self).__init__()
23 | self.block = nn.Sequential(
24 | conv3x3(channel_num, channel_num),
25 | nn.BatchNorm2d(channel_num),
26 | nn.ReLU(True),
27 | conv3x3(channel_num, channel_num),
28 | nn.BatchNorm2d(channel_num))
29 | self.relu = nn.ReLU(inplace=True)
30 |
31 | def forward(self, x):
32 | residual = x
33 | out = self.block(x)
34 | out += residual
35 | out = self.relu(out)
36 | return out
37 |
38 |
39 |
40 | def _downsample(in_channels, out_channels):
41 | return nn.Sequential(
42 | nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
43 | nn.BatchNorm2d(out_channels),
44 | nn.LeakyReLU(0.2, inplace=True)
45 | )
46 |
47 | def _upsample(in_channels, out_channels):
48 | return nn.Sequential(
49 | nn.Upsample(scale_factor=2, mode='nearest'),
50 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
51 | nn.BatchNorm2d(out_channels),
52 | nn.ReLU(inplace=True)
53 | )
54 |
55 | class CAug(nn.Module):
56 | """Module for conditional augmentation.
57 | Takes input as bert embeddings of annotations and sends output to Stage 1 and 2 generators.
58 | """
59 | def __init__(self, emb_dim=768, n_g=128, device="cuda"):
60 | """
61 | @param emb_dim (int) : Size of annotation embeddings.
62 | @param n_g (int) : Dimension of mu, epsilon and c_0_hat
63 | @param device (torch.device) : cuda/cpu
64 | """
65 | super(CAug, self).__init__()
66 | self.emb_dim = emb_dim
67 | self.n_g = n_g
68 | self.fc = nn.Linear(self.emb_dim, self.n_g*2, bias=True) # To split in mu and sigma
69 | self.relu = nn.ReLU()
70 | self.device = device
71 |
72 | def forward(self, text_emb):
73 | """
74 | @param text_emb (torch.tensor): Text embedding. (batch, emb_dim)
75 | @returns c_0_hat (torch.tensor): Gaussian conditioning variable. (batch, n_g)
76 | """
77 | enc = self.relu(self.fc(text_emb)).squeeze(1) # (batch, n_g*2)
78 |
79 | mu = enc[:, :self.n_g] # (batch, n_g)
80 | logvar = enc[:, self.n_g:] # (batch, n_g)
81 |
82 | sigma = (logvar * 0.5).exp_() # exp(logvar * 0.5) = exp(log(var^0.5)) = sqrt(var) = std
83 |
84 | epsilon = Variable(torch.FloatTensor(sigma.size()).normal_())
85 |
86 | c_0_hat = epsilon.to(self.device) * sigma + mu # (batch, n_g)
87 |
88 | return c_0_hat, mu, logvar
89 |
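CAug returns mu and logvar alongside the sampled c_0_hat so that the training loop can regularise the conditioning distribution towards a unit Gaussian. A minimal sketch of that KL term (weighted by --TRAIN_COEFF_KL in args.py; the loss actually used in training lives in the training code and may use a different reduction):

    import torch

    def kl_loss(mu, logvar):
        # KL( N(mu, sigma^2) || N(0, I) ) = -0.5 * (1 + logvar - mu^2 - exp(logvar)), averaged here
        return torch.mean(-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()))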
90 | ######################### STAGE 1 #########################
91 |
92 | class Stage1Generator(nn.Module):
93 | """
94 | Stage 1 generator.
95 | Takes the text embedding, applies Conditional Augmentation, and outputs a 64x64 image for Stage1Discriminator.
96 | """
97 | def __init__(self, n_g=128, n_z=100, emb_dim=768):
98 | """
99 | @param n_g (int) : Dimension of c_0_hat.
100 | @param n_z (int) : Dimension of noise vector.
101 | """
102 | super(Stage1Generator, self).__init__()
103 | self.n_g = n_g
104 | self.n_z = n_z
105 | self.emb_dim = emb_dim
106 | self.inp_ch = self.n_g*8
107 |
108 | # (batch, bert_size) -> (batch, n_g)
109 | self.caug = CAug(emb_dim=self.emb_dim)
110 |
111 | # (batch, n_g + n_z) -> (batch, inp_ch * 4 * 4)
112 | self.fc = nn.Sequential(
113 | nn.Linear(self.n_g + self.n_z, self.inp_ch * 4 * 4, bias=False),
114 | nn.BatchNorm1d(self.inp_ch * 4 * 4),
115 | nn.ReLU(True)
116 | )
117 |
118 | # (batch, inp_ch, 4, 4) -> (batch, inp_ch//2, 8, 8)
119 | self.up1 = _upsample(self.inp_ch, self.inp_ch // 2)
120 | # -> (batch, inp_ch//4, 16, 16)
121 | self.up2 = _upsample(self.inp_ch // 2, self.inp_ch // 4)
122 | # -> (batch, inp_ch//8, 32, 32)
123 | self.up3 = _upsample(self.inp_ch // 4, self.inp_ch // 8)
124 | # -> (batch, inp_ch//16, 64, 64)
125 | self.up4 = _upsample(self.inp_ch // 8, self.inp_ch // 16)
126 |
127 | # -> (batch, 3, 64, 64)
128 | self.img = nn.Sequential(
129 | conv3x3(self.inp_ch // 16, 3),
130 | nn.Tanh()
131 | )
132 |
133 | def forward(self, text_emb, noise):
134 | """
135 | @param text_emb, noise (torch.tensor) : Text embedding (batch, emb_dim) and noise vector (batch, n_z)
136 | @returns (None, fake_img, mu, logvar) : fake_img is the Stage 1 image output (batch, 3, 64, 64)
137 | """
138 | c_0_hat, mu, logvar = self.caug(text_emb)
139 |
140 | # -> (batch, n_g + n_z) (batch, 128 + 100)
141 | c_z = torch.cat((c_0_hat, noise), dim=1)
142 |
143 | # -> (batch, 1024 * 4 * 4)
144 | inp = self.fc(c_z)
145 |
146 | # -> (batch, 1024, 4, 4)
147 | inp = inp.view(-1, self.inp_ch, 4, 4)
148 |
149 | inp = self.up1(inp) # (batch, 512, 8, 8)
150 | inp = self.up2(inp) # (batch, 256, 16, 16)
151 | inp = self.up3(inp) # (batch, 128, 32, 32)
152 | inp = self.up4(inp) # (batch, 64, 64, 64)
153 |
154 | fake_img = self.img(inp) # (batch, 3, 64, 64)
155 | return None, fake_img, mu, logvar
156 |
157 | class Stage1Discriminator(nn.Module):
158 | """
159 | Stage 1 discriminator
160 | """
161 | def __init__(self, n_d=128, m_d=4, emb_dim=768, img_dim=64):
162 | super(Stage1Discriminator, self).__init__()
163 | self.n_d = n_d
164 | self.m_d = m_d
165 | self.emb_dim = emb_dim
166 |
167 | self.fc_for_text = nn.Linear(self.emb_dim, self.n_d)
168 | self.down_sample = nn.Sequential(
169 | # (batch, 3, 64, 64) -> (batch, img_dim, 32, 32)
170 | nn.Conv2d(3, img_dim, kernel_size=4, stride=2, padding=1, bias=False), # (batch, 64, 32, 32)
171 | nn.LeakyReLU(0.2, inplace=True),
172 | # -> (batch, img_dim * 2, 16, 16)
173 | _downsample(img_dim, img_dim*2), # (batch, 128, 16, 16)
174 | # -> (batch, img_dim * 4, 8, 8)
175 | _downsample(img_dim*2, img_dim*4), # (batch, 256, 8, 8)
176 | # -> (batch, img_dim * 8, 4, 4)
177 | _downsample(img_dim*4, img_dim*8) # (batch, 512, 4, 4)
178 | )
179 |
180 | self.out_logits = nn.Sequential(
181 | # (batch, img_dim*8 + n_d, 4, 4) -> (batch, img_dim*8, 4, 4)
182 | conv3x3(img_dim*8 + self.n_d, img_dim*8),
183 | nn.BatchNorm2d(img_dim*8),
184 | nn.LeakyReLU(0.2, inplace=True),
185 | # -> (batch, 1)
186 | nn.Conv2d(img_dim*8, 1, kernel_size=4, stride=4),
187 | nn.Sigmoid()
188 | )
189 |
190 | def forward(self, text_emb, img):
191 | # image encode
192 | enc = self.down_sample(img)
193 |
194 | # text emb
195 | compressed = self.fc_for_text(text_emb)
196 | compressed = compressed.unsqueeze(2).unsqueeze(3).repeat(1, 1, self.m_d, self.m_d)
197 |
198 | con = torch.cat((enc, compressed), dim=1)
199 |
200 | output = self.out_logits(con)
201 | return output.view(-1)
202 |
203 |
204 | ######################### STAGE 2 #########################
205 | class Stage2Generator(nn.Module):
206 | """
207 | Stage 2 generator.
208 | Wraps a frozen Stage 1 generator, applies Conditional Augmentation, and outputs a 256x256 image for Stage2Discriminator.
209 | """
210 | def __init__(self, stage1_gen, n_g=128, n_z=100, ef_size=128, n_res=4, emb_dim=768):
211 | """
212 | @param n_g (int) : Dimension of c_0_hat.
213 | """
214 | super(Stage2Generator, self).__init__()
215 | self.n_g = n_g
216 | self.n_z = n_z
217 | self.ef_size = ef_size
218 | self.n_res = n_res
219 | self.emb_dim = emb_dim
220 |
221 | self.stage1_gen = stage1_gen
222 | # Freezing the stage 1 generator:
223 | for param in self.stage1_gen.parameters():
224 | param.requires_grad = False
225 |
226 | # (batch, bert_size) -> (batch, n_g)
227 | self.caug = CAug(emb_dim=self.emb_dim)
228 |
229 | # -> (batch, n_g*4, 16, 16)
230 | self.encoder = nn.Sequential(
231 | conv3x3(3, n_g), # (batch, 128, 64, 64)
232 | nn.LeakyReLU(0.2, inplace=True), #? Paper: leaky, code: relu
233 | _downsample(n_g, n_g*2), # (batch, 256, 32, 32)
234 | _downsample(n_g*2, n_g*4) # (batch, 512, 16, 16)
235 | )
236 |
237 | # (batch, ef_size + n_g * 4, 16, 16) -> (batch, n_g * 4, 16, 16)
238 | # (batch, 128 + 512, 16, 16) -> (batch, 512, 16, 16)
239 | self.cat_conv = nn.Sequential(
240 | conv3x3(self.ef_size + self.n_g * 4, self.n_g * 4),
241 | nn.BatchNorm2d(self.n_g * 4),
242 | nn.ReLU(inplace=True)
243 | )
244 |
245 | # -> (batch, n_g * 4, 16, 16)
246 | # (batch, 512, 16, 16)
247 | self.residual = nn.Sequential(
248 | *[
249 | ResBlock(self.n_g * 4) for _ in range(self.n_res)
250 | ]
251 | )
252 |
253 | # -> (batch, n_g * 2, 32, 32)
254 | self.up1 = _upsample(n_g * 4, n_g * 2) # (batch, 256, 32, 32)
255 | # -> (batch, n_g, 64, 64)
256 | self.up2 = _upsample(n_g * 2, n_g) # (batch, 128, 64, 64)
257 | # -> (batch, n_g // 2, 128, 128)
258 | self.up3 = _upsample(n_g, n_g // 2) # (batch, 64, 128, 128)
259 | # -> (batch, n_g // 4, 256, 256)
260 | self.up4 = _upsample(n_g // 2, n_g // 4) # (batch, 32, 256, 256)
261 |
262 | # (batch, 3, 256, 256)
263 | self.img = nn.Sequential(
264 | conv3x3(n_g // 4, 3),
265 | nn.Tanh()
266 | )
267 |
268 | def forward(self, text_emb, noise):
269 | """
270 | @param text_emb (torch.tensor) : Text embedding (batch, emb_dim)
271 | @param noise (torch.tensor) : Noise vector (batch, n_z)
272 | @returns (stage1_img, fake_img, mu, logvar) : fake_img is the Stage 2 image output (batch, 3, 256, 256)
273 | """
274 | _, stage1_img, _, _ = self.stage1_gen(text_emb, noise)
275 | stage1_img = stage1_img.detach()
276 |
277 | encoded_img = self.encoder(stage1_img)
278 |
279 | c_0_hat, mu, logvar = self.caug(text_emb)
280 | c_0_hat = c_0_hat.unsqueeze(2).unsqueeze(3).repeat(1, 1, 16, 16)
281 |
282 | # -> (batch, ef_size + n_g * 4, 16, 16) # (batch, 640, 16, 16)
283 | concat_out = torch.cat((encoded_img, c_0_hat), dim=1)
284 |
285 | # -> (batch, n_g * 4, 16, 16)
286 | h_out = self.cat_conv(concat_out)
287 | h_out = self.residual(h_out)
288 |
289 | h_out = self.up1(h_out)
290 | h_out = self.up2(h_out)
291 | h_out = self.up3(h_out)
292 | # -> (batch, ng // 4, 256, 256)
293 | h_out = self.up4(h_out)
294 |
295 | # -> (batch, 3, 256, 256)
296 | fake_img = self.img(h_out)
297 |
298 | return stage1_img, fake_img, mu, logvar
299 |
300 |
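A hedged sketch of how a trained Stage 1 generator would typically be plugged into Stage 2 (the checkpoint path is the --STAGE1_G_path default from args.py; this assumes the checkpoint stores a plain state_dict, and the real wiring is done by the training code):

    import torch

    stage1 = Stage1Generator(emb_dim=768)
    stage1.load_state_dict(torch.load("../old_outputs/output_0/model/netG_epoch_120.pth",
                                      map_location="cpu"))
    stage2 = Stage2Generator(stage1, emb_dim=768)   # __init__ above sets requires_grad=False on stage1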
301 | class Stage2Discriminator(nn.Module):
302 | """
303 | Stage 2 discriminator
304 | """
305 |
306 | def __init__(self, n_d=128, m_d=4, emb_dim=768, img_dim=256):
307 | super(Stage2Discriminator, self).__init__()
308 | self.n_d = n_d
309 | self.m_d = m_d
310 | self.emb_dim = emb_dim
311 |
312 | self.fc_for_text = nn.Linear(self.emb_dim, self.n_d)
313 | self.down_sample = nn.Sequential(
314 | # (batch, 3, 256, 256) -> (batch, img_dim//4, 128, 128)
315 | nn.Conv2d(3, img_dim//4, kernel_size=4, stride=2, padding=1, bias=False), # (batch, 64, 128, 128)
316 | nn.LeakyReLU(0.2, inplace=True),
317 | # -> (batch, img_dim//2, 64, 64)
318 | _downsample(img_dim//4, img_dim//2), # (batch, 128, 64, 64)
319 | # -> (batch, img_dim, 32, 32)
320 | _downsample(img_dim//2, img_dim), # (batch, 256, 32, 32)
321 | # -> (batch, img_dim*2, 16, 16)
322 | _downsample(img_dim, img_dim*2), # (batch, 512, 16, 16)
323 | # -> (batch, img_dim*4, 8, 8)
324 | _downsample(img_dim*2, img_dim*4), # (batch, 1024, 8, 8)
325 | # -> (batch, img_dim*8, 4, 4)
326 | _downsample(img_dim*4, img_dim*8), # (batch, 2048, 4, 4)
327 | # -> (batch, img_dim*4, 4, 4)
328 | conv3x3(img_dim*8, img_dim*4), # (batch, 1024, 4, 4)
329 | nn.BatchNorm2d(img_dim*4),
330 | nn.LeakyReLU(0.2, inplace=True),
331 | # -> (batch, img_dim*2, 4, 4)
332 | conv3x3(img_dim * 4, img_dim * 2), # (batch, 512, 4, 4)
333 | nn.BatchNorm2d(img_dim * 2),
334 | nn.LeakyReLU(0.2, inplace=True)
335 | )
336 |
337 | self.out_logits = nn.Sequential(
338 | # (batch, img_dim*2 + n_d, 4, 4) -> (batch, img_dim*2, 4, 4)
339 | conv3x3(img_dim*2 + self.n_d, img_dim*2),
340 | nn.BatchNorm2d(img_dim*2),
341 | nn.LeakyReLU(0.2, inplace=True),
342 | # -> (batch, 1)
343 | nn.Conv2d(img_dim*2, 1, kernel_size=4, stride=4),
344 | nn.Sigmoid()
345 | )
346 |
347 | def forward(self, text_emb, img):
348 | # image encode
349 | enc = self.down_sample(img)
350 |
351 | # text emb
352 | compressed = self.fc_for_text(text_emb)
353 | compressed = compressed.unsqueeze(2).unsqueeze(3).repeat(1, 1, self.m_d, self.m_d)
354 |
355 | con = torch.cat((enc, compressed), dim=1)
356 |
357 | output = self.out_logits(con)
358 | return output.view(-1)
359 |
360 | ######################### #########################
361 |
362 |
363 | if __name__ == "__main__":
364 | batch_size = 2
365 | n_z = 100
366 | emb_dim = 1024 # 768
367 | emb = torch.randn((batch_size, emb_dim))
368 | noise = torch.empty((batch_size, n_z)).normal_()
369 |
370 | generator1 = Stage1Generator(emb_dim=emb_dim)
371 | generator2 = Stage2Generator(generator1, emb_dim=emb_dim)
372 |
373 | discriminator1 = Stage1Discriminator(emb_dim=emb_dim)
374 | discriminator2 = Stage2Discriminator(emb_dim=emb_dim)
375 |
376 |
377 | _, gen1, _, _ = generator1(emb, noise)
378 | print("output1 image dimensions :", gen1.size()) # (batch_size, 3, 64, 64)
379 | assert gen1.shape == (batch_size, 3, 64, 64)
380 | print()
381 |
382 | disc1 = discriminator1(emb, gen1)
383 | print("output1 discriminator", disc1.size()) # (batch_size)
384 | # assert disc1.shape == (batch_size,)
385 | print()
386 |
387 | _, gen2, _, _ = generator2(emb, noise)
388 | print("output2 image dimensions :", gen2.size()) # (batch_size, 3, 256, 256)
389 | assert gen2.shape == (batch_size, 3, 256, 256)
390 | print()
391 |
392 | disc2 = discriminator2(emb, gen2)
393 | print("output2 discriminator", disc2.size()) # (batch_size)
394 | # assert disc2.shape == (batch_size,)
395 | print()
396 |
397 | ca = CAug(emb_dim=emb_dim, n_g=128, device='cpu')
398 | out_ca, _, _ = ca(emb)
399 | print("Conditional Aug output size: ", out_ca.size()) # (batch_size, 128)
400 | assert out_ca.shape == (batch_size, 128)
401 |
402 |
403 | ###* Checking init weights
404 | # import engine
405 | # netG = Stage1Generator()
406 | # netG.apply(engine.weights_init)
407 | pass
408 |
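Running this module directly executes the shape smoke tests above on random tensors:

    # > python3 layers.py

Note that CAug defaults to device="cuda", so the generator checks expect a GPU unless that default (or the CAug construction) is changed.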
--------------------------------------------------------------------------------