├── data
│   └── readme.txt
├── code
│   ├── src
│   │   ├── readme.txt
│   │   ├── sample.ipynb
│   │   ├── test.py
│   │   └── train.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── inception.py
│   │   └── GALIP.py
│   ├── cfg
│   │   ├── coco.yml
│   │   └── birds.yml
│   ├── scripts
│   │   ├── test.sh
│   │   └── train.sh
│   └── lib
│       ├── perpare.py
│       ├── datasets.py
│       ├── utils.py
│       └── modules.py
├── logo.jpeg
├── results.jpg
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md
/data/readme.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/src/readme.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/logo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobran/GALIP/HEAD/logo.jpeg
--------------------------------------------------------------------------------
/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tobran/GALIP/HEAD/results.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | numpy
3 | torch==1.11.0
4 | torchvision
5 | easydict
6 | pandas
7 | jupyter
8 | pyyaml
9 | ipykernel
10 | scipy
11 | tensorboard
12 | ftfy
13 | regex
14 | tqdm
15 |
--------------------------------------------------------------------------------
/code/cfg/coco.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'coco'
2 | dataset_name: 'coco'
3 | data_dir: '../data/coco'
4 |
5 | imsize: 256
6 | z_dim: 100
7 | cond_dim: 512
8 | manual_seed: 100
9 | cuda: True
10 |
11 | clip4evl: {'src':"clip", 'type':'ViT-B/32'}
12 | clip4trn: {'src':"clip", 'type':'ViT-B/32'}
13 | clip4text: {'src':"clip", 'type':'ViT-B/32'}
14 |
15 | stamp: 'normal'
16 | state_epoch: 0
17 | max_epoch: 3005
18 | batch_size: 16
19 | gpu_id: 0
20 | nf: 64
21 | ch_size: 3
22 |
23 | scaler_min: 64
24 | growth_interval: 2000
25 | lr_g: 0.0001
26 | lr_d: 0.0004
27 | sim_w: 4.0
28 |
29 | gen_interval: 1 #1
30 | test_interval: 5 #5
31 | save_interval: 5
32 |
33 | sample_times: 1
34 | npz_path: '../data/coco/npz/coco_val256_FIDK0.npz'
35 | log_dir: 'new'
36 |
--------------------------------------------------------------------------------
/code/cfg/birds.yml:
--------------------------------------------------------------------------------
1 | CONFIG_NAME: 'bird'
2 | dataset_name: 'birds'
3 | data_dir: '../data/birds'
4 |
5 | imsize: 256
6 | z_dim: 100
7 | cond_dim: 512
8 | manual_seed: 100
9 | cuda: True
10 |
11 | clip4evl: {'src':"clip", 'type':'ViT-B/32'}
12 | clip4trn: {'src':"clip", 'type':'ViT-B/32'}
13 | clip4text: {'src':"clip", 'type':'ViT-B/32'}
14 |
15 | stamp: 'normal'
16 | state_epoch: 0
17 | max_epoch: 1502
18 | batch_size: 16
19 | gpu_id: 0
20 | nf: 64
21 | ch_size: 3
22 |
23 | scaler_min: 64
24 | growth_interval: 2000
25 | lr_g: 0.0001
26 | lr_d: 0.0004
27 | sim_w: 4.0
28 |
29 | gen_interval: 5 #1
30 | test_interval: 20 #5
31 | save_interval: 20
32 |
33 | sample_times: 12
34 | npz_path: '../data/birds/npz/bird_val256_FIDK0.npz'
35 | log_dir: 'new'
36 |
--------------------------------------------------------------------------------
/code/scripts/test.sh:
--------------------------------------------------------------------------------
1 | cfg=$1
2 | batch_size=64
3 |
4 | pretrained_model='./saved_models/data/model_save_file/xxx.pth'
5 | multi_gpus=True
6 | mixed_precision=True
7 |
8 | nodes=1
9 | num_workers=8
10 | master_port=11277
11 | stamp=gpu${nodes}MP_${mixed_precision}
12 |
13 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=$nodes --master_port=$master_port src/test.py \
14 | --stamp $stamp \
15 | --cfg $cfg \
16 | --mixed_precision $mixed_precision \
17 | --batch_size $batch_size \
18 | --num_workers $num_workers \
19 | --multi_gpus $multi_gpus \
20 | --pretrained_model_path $pretrained_model \
21 |
--------------------------------------------------------------------------------
/code/scripts/train.sh:
--------------------------------------------------------------------------------
1 | cfg=$1
2 | batch_size=64
3 |
4 | state_epoch=1
5 | pretrained_model_path='./saved_models/data/model_save_file'
6 | log_dir='new'
7 |
8 | multi_gpus=True
9 | mixed_precision=True
10 |
11 | nodes=8
12 | num_workers=8
13 | master_port=11266
14 | stamp=gpu${nodes}MP_${mixed_precision}
15 |
16 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$nodes --master_port=$master_port src/train.py \
17 | --stamp $stamp \
18 | --cfg $cfg \
19 | --mixed_precision $mixed_precision \
20 | --log_dir $log_dir \
21 | --batch_size $batch_size \
22 | --state_epoch $state_epoch \
23 | --num_workers $num_workers \
24 | --multi_gpus $multi_gpus \
25 | --pretrained_model_path $pretrained_model_path \
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 MingTao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | pip-wheel-metadata/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
94 | __pypackages__/
95 |
96 | # Celery stuff
97 | celerybeat-schedule
98 | celerybeat.pid
99 |
100 | # SageMath parsed files
101 | *.sage.py
102 |
103 | # Environments
104 | .env
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | logs/
131 | imgs/
132 | samples/
133 | saved_models/
134 | data/*
135 | __pycache__
136 | tmp/
--------------------------------------------------------------------------------
/code/src/sample.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import os\n",
11 | "from PIL import Image\n",
12 | "import clip\n",
13 | "import os.path as osp\n",
14 | "import os, sys\n",
15 | "import torchvision.utils as vutils\n",
16 | "sys.path.insert(0, '../')\n",
17 | "\n",
18 | "from lib.utils import load_model_weights,mkdir_p\n",
19 | "from models.GALIP import NetG, CLIP_TXT_ENCODER"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": 2,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "device = 'cpu' # 'cpu' # 'cuda:0'\n",
29 | "CLIP_text = \"ViT-B/32\"\n",
30 | "clip_model, preprocess = clip.load(\"ViT-B/32\", device=device)\n",
31 | "clip_model = clip_model.eval()"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": 3,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)\n",
41 | "netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)\n",
42 | "path = '../saved_models/pretrained/pre_cc12m.pth'\n",
43 | "checkpoint = torch.load(path, map_location=torch.device('cpu'))\n",
44 | "netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus=False)"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 4,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "batch_size = 8\n",
54 | "noise = torch.randn((batch_size, 100)).to(device)"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 5,
60 | "metadata": {},
61 | "outputs": [],
62 | "source": [
63 | "captions = ['Line drawing illustration of a kawaii cute ghost.']"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 6,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "mkdir_p('./samples')"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": 7,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "# generate from text\n",
82 | "with torch.no_grad():\n",
83 | " for i in range(len(captions)):\n",
84 | " caption = captions[i]\n",
85 | " tokenized_text = clip.tokenize([caption]).to(device)\n",
86 | " sent_emb, word_emb = text_encoder(tokenized_text)\n",
87 | " sent_emb = sent_emb.repeat(batch_size,1)\n",
88 | " fake_imgs = netG(noise,sent_emb,eval=True).float()\n",
89 | " name = f'{captions[i].replace(\" \", \"-\")}'\n",
90 | " vutils.save_image(fake_imgs.data, '../samples/%s.png'%(name), nrow=8, value_range=(-1, 1), normalize=True)"
91 | ]
92 | }
93 | ],
94 | "metadata": {
95 | "kernelspec": {
96 | "display_name": "dfgan",
97 | "language": "python",
98 | "name": "python3"
99 | },
100 | "language_info": {
101 | "codemirror_mode": {
102 | "name": "ipython",
103 | "version": 3
104 | },
105 | "file_extension": ".py",
106 | "mimetype": "text/x-python",
107 | "name": "python",
108 | "nbconvert_exporter": "python",
109 | "pygments_lexer": "ipython3",
110 | "version": "3.9.0"
111 | },
112 | "orig_nbformat": 4,
113 | "vscode": {
114 | "interpreter": {
115 | "hash": "849434eb86c3997df801551b732438d01b491303b69c29efcf332721ce6d8430"
116 | }
117 | }
118 | },
119 | "nbformat": 4,
120 | "nbformat_minor": 2
121 | }
122 |
--------------------------------------------------------------------------------
/code/lib/perpare.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import os.path as osp
3 | import numpy as np
4 | from PIL import Image
5 | from tqdm import tqdm, trange
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision.transforms as transforms
9 | import torchvision.utils as vutils
10 | from torchvision.utils import save_image, make_grid
11 | from torch.utils.data import DataLoader, random_split
12 | from torch.utils.data.distributed import DistributedSampler
13 | import clip
14 | import importlib
15 | from lib.utils import choose_model
16 |
17 |
18 | ########### preparation ############
19 | def load_clip(clip_info, device):
20 | import clip as clip
21 | model = clip.load(clip_info['type'], device=device)[0]
22 | return model
23 |
24 |
25 | def prepare_models(args):
26 | device = args.device
27 | local_rank = args.local_rank
28 | multi_gpus = args.multi_gpus
29 | CLIP4trn = load_clip(args.clip4trn, device).eval()
30 | CLIP4evl = load_clip(args.clip4evl, device).eval()
31 | NetG,NetD,NetC,CLIP_IMG_ENCODER,CLIP_TXT_ENCODER = choose_model(args.model)
32 | # image encoder
33 | CLIP_img_enc = CLIP_IMG_ENCODER(CLIP4trn).to(device)
34 | for p in CLIP_img_enc.parameters():
35 | p.requires_grad = False
36 | CLIP_img_enc.eval()
37 | # text encoder
38 | CLIP_txt_enc = CLIP_TXT_ENCODER(CLIP4trn).to(device)
39 | for p in CLIP_txt_enc.parameters():
40 | p.requires_grad = False
41 | CLIP_txt_enc.eval()
42 | # GAN models
43 | netG = NetG(args.nf, args.z_dim, args.cond_dim, args.imsize, args.ch_size, args.mixed_precision, CLIP4trn).to(device)
44 | netD = NetD(args.nf, args.imsize, args.ch_size, args.mixed_precision).to(device)
45 | netC = NetC(args.nf, args.cond_dim, args.mixed_precision).to(device)
46 | if (args.multi_gpus) and (args.train):
47 | print("Let's use", torch.cuda.device_count(), "GPUs!")
48 | netG = torch.nn.parallel.DistributedDataParallel(netG, broadcast_buffers=False,
49 | device_ids=[local_rank],
50 | output_device=local_rank, find_unused_parameters=True)
51 | netD = torch.nn.parallel.DistributedDataParallel(netD, broadcast_buffers=False,
52 | device_ids=[local_rank],
53 | output_device=local_rank, find_unused_parameters=True)
54 | netC = torch.nn.parallel.DistributedDataParallel(netC, broadcast_buffers=False,
55 | device_ids=[local_rank],
56 | output_device=local_rank, find_unused_parameters=True)
57 | return CLIP4trn, CLIP4evl, CLIP_img_enc, CLIP_txt_enc, netG, netD, netC
58 |
59 |
60 | def prepare_dataset(args, split, transform):
61 | if args.ch_size!=3:
62 | imsize = 256
63 | else:
64 | imsize = args.imsize
65 | if transform is not None:
66 | image_transform = transform
67 | else:
68 | image_transform = transforms.Compose([
69 | transforms.Resize(int(imsize * 76 / 64)),
70 | transforms.RandomCrop(imsize),
71 | transforms.RandomHorizontalFlip(),
72 | ])
73 | from lib.datasets import TextImgDataset as Dataset
74 | dataset = Dataset(split=split, transform=image_transform, args=args)
75 | return dataset
76 |
77 |
78 | def prepare_datasets(args, transform):
79 | # train dataset
80 | train_dataset = prepare_dataset(args, split='train', transform=transform)
81 | # test dataset
82 | val_dataset = prepare_dataset(args, split='test', transform=transform)
83 | return train_dataset, val_dataset
84 |
85 |
86 | def prepare_dataloaders(args, transform=None):
87 | batch_size = args.batch_size
88 | num_workers = args.num_workers
89 | train_dataset, valid_dataset = prepare_datasets(args, transform)
90 | # train dataloader
91 | if args.multi_gpus==True:
92 | train_sampler = DistributedSampler(train_dataset)
93 | train_dataloader = torch.utils.data.DataLoader(
94 | train_dataset, batch_size=batch_size, drop_last=True,
95 | num_workers=num_workers, sampler=train_sampler)
96 | else:
97 | train_sampler = None
98 | train_dataloader = torch.utils.data.DataLoader(
99 | train_dataset, batch_size=batch_size, drop_last=True,
100 | num_workers=num_workers, shuffle=True)
101 | # valid dataloader
102 | if args.multi_gpus==True:
103 | valid_sampler = DistributedSampler(valid_dataset)
104 | valid_dataloader = torch.utils.data.DataLoader(
105 | valid_dataset, batch_size=batch_size, drop_last=True,
106 | num_workers=num_workers, sampler=valid_sampler)
107 | else:
108 | valid_dataloader = torch.utils.data.DataLoader(
109 | valid_dataset, batch_size=batch_size, drop_last=True,
110 | num_workers=num_workers, shuffle=True)
111 | return train_dataloader, valid_dataloader, \
112 | train_dataset, valid_dataset, train_sampler
113 |
114 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | [](https://github.com/tobran/GALIP/blob/master/LICENSE.md)
3 | 
4 | 
5 | 
6 | 
7 | []((https://github.com/tobran/GALIP/graphs/commit-activity))
8 | 
9 |
10 | # GALIP: Generative Adversarial CLIPs for Text-to-Image Synthesis (CVPR 2023)
11 |
12 |
13 |
14 |
15 |
16 | # A high-quality, fast, and efficient text-to-image model
17 |
18 | Official PyTorch implementation for our paper [GALIP: Generative Adversarial CLIPs for Text-to-Image Synthesis](https://arxiv.org/abs/2301.12959) by [Ming Tao](https://scholar.google.com/citations?user=5GlOlNUAAAAJ), [Bing-Kun Bao](https://scholar.google.com/citations?user=lDppvmoAAAAJ&hl=en), [Hao Tang](https://scholar.google.com/citations?user=9zJkeEMAAAAJ&hl=en), [Changsheng Xu](https://scholar.google.com/citations?user=hI9NRDkAAAAJ).
19 |
20 |
21 | Generated Images
22 |
23 |
24 |
25 |
26 |
27 |
28 | ## Requirements
29 | - Python 3.9
30 | - PyTorch 1.9
31 | - At least one 24GB RTX 3090 GPU (for training)
32 | - CPU only (for sampling)
33 |
34 | GALIP is a small and fast generative model that can generate multiple images in one second, even on a CPU.
35 | ## Installation
36 |
37 | Clone this repo.
38 | ```
39 | git clone https://github.com/tobran/GALIP
40 | pip install -r requirements.txt
41 | ```
42 | Install [CLIP](https://github.com/openai/CLIP)
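
As a quick sanity check (optional, not part of the original instructions), the ViT-B/32 model used throughout the configs can be loaded directly; `clip.load` downloads the weights on first use:
```
import torch
import clip

# load the CLIP ViT-B/32 model referenced by cfg/coco.yml and cfg/birds.yml
device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, preprocess = clip.load("ViT-B/32", device=device)
print(clip_model.visual.input_resolution)  # 224
```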
43 |
44 | ## Preparation (Same as DF-GAN)
45 | ### Datasets
46 | 1. Download the preprocessed metadata for [birds](https://drive.google.com/file/d/1I6ybkR7L64K8hZOraEZDuHh0cCJw5OUj/view?usp=sharing) and [coco](https://drive.google.com/file/d/15Fw-gErCEArOFykW3YTnLKpRcPgI_3AB/view?usp=sharing) and extract them to `data/`
47 | 2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data. Extract them to `data/birds/`
48 | 3. Download the [coco2014](http://cocodataset.org/#download) dataset and extract the images to `data/coco/images/` (the expected layout is sketched below)
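
Based on the path templates in `code/lib/datasets.py` and the `npz_path` entries in the configs, the resulting layout should look roughly like this:
```
data
├── birds
│   ├── CUB_200_2011/images/          # bird photos
│   ├── text/                         # caption .txt files
│   ├── train/filenames.pickle
│   ├── test/filenames.pickle
│   └── npz/bird_val256_FIDK0.npz     # FID statistics
└── coco
    ├── images/train2014/jpg/
    ├── images/val2014/jpg/
    ├── text/
    ├── train/filenames.pickle
    ├── test/filenames.pickle
    └── npz/coco_val256_FIDK0.npz
```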
49 |
50 | ## Training
51 | ```
52 | cd GALIP/code/
53 | ```
54 | ### Train the GALIP model
55 | - For bird dataset: `bash scripts/train.sh ./cfg/birds.yml`
56 | - For coco dataset: `bash scripts/train.sh ./cfg/coco.yml`
57 | ### Resume training process
58 | If your training process is interrupted unexpectedly, set **state_epoch**, **log_dir**, and **pretrained_model_path** in train.sh to resume training.
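
For reference, a minimal sketch of how `src/train.py` resolves the checkpoint when resuming (the file name pattern matches what `save_models_opt` in `lib/utils.py` writes); the values below are illustrative:
```
import os.path as osp

state_epoch = 100                                              # illustrative: last saved epoch
pretrained_model_path = './saved_models/data/model_save_file'  # illustrative save directory

# train.py loads <pretrained_model_path>/state_epoch_<state_epoch>.pth and resumes at state_epoch + 1
path = osp.join(pretrained_model_path, 'state_epoch_%03d.pth' % state_epoch)
print(path)  # ./saved_models/data/model_save_file/state_epoch_100.pth
```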
59 |
60 | ### TensorBoard
61 | Our code supports automated FID evaluation during training; the results are stored in TensorBoard files under `./logs`. You can change the test interval by changing **test_interval** in the YAML file.
62 |
63 | - For bird dataset: `tensorboard --logdir=./code/logs/bird/train --port 8166`
64 | - For coco dataset: `tensorboard --logdir=./code/logs/coco/train --port 8177`
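
FID is computed against precomputed InceptionV3 statistics; the `npz_path` entry in the YAML config points to the statistics file, which `lib/utils.load_npz` reads as a mean/covariance pair. A minimal sketch, assuming the birds statistics file is present:
```
import numpy as np

# npz_path from cfg/birds.yml; the archive stores the statistics under 'mu' and 'sigma'
f = np.load('../data/birds/npz/bird_val256_FIDK0.npz')
mu, sigma = f['mu'], f['sigma']
f.close()
print(mu.shape, sigma.shape)  # typically (2048,) and (2048, 2048)
```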
65 |
66 |
67 | ## Evaluation
68 |
69 | ### Download Pretrained Model
70 | - [GALIP for COCO](https://drive.google.com/file/d/1gbfwDeD7ftZmdOFxfffCjKCyYfF4ptdl/view?usp=sharing). Download and save it to `./code/saved_models/pretrained/`
71 | - [GALIP for CC12M](https://drive.google.com/file/d/1VnONvNRjuyHTzuLKBbozZ38-WIt7XZMC/view?usp=sharing). Download and save it to `./code/saved_models/pretrained/`
72 |
73 | ### Evaluate GALIP models
74 |
75 | ```
76 | cd GALIP/code/
77 | ```
78 | Set **pretrained_model** in `test.sh`, then:
79 | - For bird dataset: `bash scripts/test.sh ./cfg/birds.yml`
80 | - For COCO dataset: `bash scripts/test.sh ./cfg/coco.yml`
81 | - For CC12M (zero-shot on COCO) dataset: `bash scripts/test.sh ./cfg/coco.yml`
82 |
83 | ### Performance
84 | The released model achieves better performance than the paper version.
85 |
86 |
87 | | Model | COCO-FID↓ | COCO-CS↑ | CC12M-ZFID↓ |
88 | | --- | --- | --- | --- |
89 | | GALIP(paper) | 5.85 | 0.3338 | 12.54 |
90 | | GALIP(released) | **5.01** | **0.3379** | **12.54** |
91 |
92 |
93 | ## Sampling
94 |
95 | ### Synthesize images from your text descriptions
96 | - The `code/src/sample.ipynb` notebook can be used to synthesize images from your own text descriptions; a condensed version is sketched below.
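
A condensed sketch of the notebook's sampling loop, runnable from the `code/` directory on CPU (it assumes the CC12M pretrained weights were saved to `./saved_models/pretrained/pre_cc12m.pth` as described above):
```
import torch
import clip
import torchvision.utils as vutils
from lib.utils import load_model_weights, mkdir_p
from models.GALIP import NetG, CLIP_TXT_ENCODER

device = 'cpu'  # 'cuda:0' also works
clip_model, _ = clip.load("ViT-B/32", device=device)
clip_model = clip_model.eval()

# text encoder and generator, initialized exactly as in sample.ipynb
text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)
netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)
checkpoint = torch.load('./saved_models/pretrained/pre_cc12m.pth', map_location='cpu')
netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus=False)

mkdir_p('./samples')
batch_size = 8
noise = torch.randn((batch_size, 100)).to(device)
caption = 'Line drawing illustration of a kawaii cute ghost.'
with torch.no_grad():
    tokens = clip.tokenize([caption]).to(device)
    sent_emb, word_emb = text_encoder(tokens)
    sent_emb = sent_emb.repeat(batch_size, 1)
    fake_imgs = netG(noise, sent_emb, eval=True).float()
vutils.save_image(fake_imgs.data, './samples/sample.png', nrow=8, normalize=True, value_range=(-1, 1))
```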
97 |
98 | ---
99 | ### Citing GALIP
100 |
101 | If you find GALIP useful in your research, please consider citing our paper:
102 | ```
103 |
104 | @inproceedings{tao2023galip,
105 | title={GALIP: Generative Adversarial CLIPs for Text-to-Image Synthesis},
106 | author={Tao, Ming and Bao, Bing-Kun and Tang, Hao and Xu, Changsheng},
107 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
108 | pages={14214--14223},
109 | year={2023}
110 | }
111 |
112 | ```
113 | The code is released for academic research use only. For commercial use, please contact Ming Tao (陶明) (mingtao2000@126.com).
114 |
115 |
116 | **Reference**
117 | - [DF-GAN: A Simple and Effective Baseline for Text-to-Image Synthesis](https://arxiv.org/abs/2008.05865) [[code]](https://github.com/tobran/DF-GAN)
118 |
--------------------------------------------------------------------------------
/code/models/inception.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torchvision import models
4 |
5 |
6 | class InceptionV3(nn.Module):
7 | """Pretrained InceptionV3 network returning feature maps"""
8 |
9 | # Index of default block of inception to return,
10 | # corresponds to output of final average pooling
11 | DEFAULT_BLOCK_INDEX = 3
12 |
13 | # Maps feature dimensionality to their output blocks indices
14 | BLOCK_INDEX_BY_DIM = {
15 | 64: 0, # First max pooling features
16 | 192: 1, # Second max pooling features
17 | 768: 2, # Pre-aux classifier features
18 | 2048: 3 # Final average pooling features
19 | }
20 |
21 | def __init__(self,
22 | output_blocks=[DEFAULT_BLOCK_INDEX],
23 | resize_input=True,
24 | normalize_input=True,
25 | requires_grad=False):
26 | """Build pretrained InceptionV3
27 |
28 | Parameters
29 | ----------
30 | output_blocks : list of int
31 | Indices of blocks to return features of. Possible values are:
32 | - 0: corresponds to output of first max pooling
33 | - 1: corresponds to output of second max pooling
34 | - 2: corresponds to output which is fed to aux classifier
35 | - 3: corresponds to output of final average pooling
36 | resize_input : bool
37 | If true, bilinearly resizes input to width and height 299 before
38 | feeding input to model. As the network without fully connected
39 | layers is fully convolutional, it should be able to handle inputs
40 | of arbitrary size, so resizing might not be strictly needed
41 | normalize_input : bool
42 | If true, normalizes the input to the statistics the pretrained
43 | Inception network expects
44 | requires_grad : bool
45 | If true, parameters of the model require gradient. Possibly useful
46 | for finetuning the network
47 | """
48 | super(InceptionV3, self).__init__()
49 |
50 | self.resize_input = resize_input
51 | self.normalize_input = normalize_input
52 | self.output_blocks = sorted(output_blocks)
53 | self.last_needed_block = max(output_blocks)
54 |
55 | assert self.last_needed_block <= 3, \
56 | 'Last possible output block index is 3'
57 |
58 | self.blocks = nn.ModuleList()
59 |
60 | inception = models.inception_v3(pretrained=True)
61 |
62 | # Block 0: input to maxpool1
63 | block0 = [
64 | inception.Conv2d_1a_3x3,
65 | inception.Conv2d_2a_3x3,
66 | inception.Conv2d_2b_3x3,
67 | nn.MaxPool2d(kernel_size=3, stride=2)
68 | ]
69 | self.blocks.append(nn.Sequential(*block0))
70 |
71 | # Block 1: maxpool1 to maxpool2
72 | if self.last_needed_block >= 1:
73 | block1 = [
74 | inception.Conv2d_3b_1x1,
75 | inception.Conv2d_4a_3x3,
76 | nn.MaxPool2d(kernel_size=3, stride=2)
77 | ]
78 | self.blocks.append(nn.Sequential(*block1))
79 |
80 | # Block 2: maxpool2 to aux classifier
81 | if self.last_needed_block >= 2:
82 | block2 = [
83 | inception.Mixed_5b,
84 | inception.Mixed_5c,
85 | inception.Mixed_5d,
86 | inception.Mixed_6a,
87 | inception.Mixed_6b,
88 | inception.Mixed_6c,
89 | inception.Mixed_6d,
90 | inception.Mixed_6e,
91 | ]
92 | self.blocks.append(nn.Sequential(*block2))
93 |
94 | # Block 3: aux classifier to final avgpool
95 | if self.last_needed_block >= 3:
96 | block3 = [
97 | inception.Mixed_7a,
98 | inception.Mixed_7b,
99 | inception.Mixed_7c,
100 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
101 | ]
102 | self.blocks.append(nn.Sequential(*block3))
103 |
104 | for param in self.parameters():
105 | param.requires_grad = requires_grad
106 |
107 | def forward(self, inp):
108 | """Get Inception feature maps
109 |
110 | Parameters
111 | ----------
112 | inp : torch.autograd.Variable
113 | Input tensor of shape Bx3xHxW. Values are expected to be in
114 | range (0, 1)
115 |
116 | Returns
117 | -------
118 | List of torch.autograd.Variable, corresponding to the selected output
119 | block, sorted ascending by index
120 | """
121 | outp = []
122 | x = inp
123 |
124 | if self.resize_input:
125 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)
126 |
127 | if self.normalize_input:
128 | x = x.clone()
129 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
130 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
131 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
132 |
133 | for idx, block in enumerate(self.blocks):
134 | x = block(x)
135 | if idx in self.output_blocks:
136 | outp.append(x)
137 |
138 | if idx == self.last_needed_block:
139 | break
140 |
141 | return outp
142 |
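# Illustrative usage (added for clarity; not in the original file): a minimal sketch of
# extracting the 2048-d final-average-pool features that FID is computed on. It downloads
# the torchvision InceptionV3 weights on first use and runs only when executed directly.
if __name__ == '__main__':
    import torch
    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]      # block 3: final average pooling
    model = InceptionV3(output_blocks=[block_idx]).eval()
    images = torch.rand(2, 3, 256, 256)                   # dummy batch with values in [0, 1]
    with torch.no_grad():
        pool_features = model(images)[0]                  # shape: (2, 2048, 1, 1)
    print(pool_features.squeeze(-1).squeeze(-1).shape)    # torch.Size([2, 2048])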
--------------------------------------------------------------------------------
/code/src/test.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import os.path as osp
3 | import time
4 | import random
5 | import argparse
6 | import numpy as np
7 | from PIL import Image
8 | import pprint
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from torch.autograd import Variable
14 | import torch.backends.cudnn as cudnn
15 | from torchvision.utils import save_image,make_grid
16 | from torch.utils.tensorboard import SummaryWriter
17 | import torchvision.transforms as transforms
18 | import torchvision.utils as vutils
19 | from torch.utils.data.distributed import DistributedSampler
20 | import multiprocessing as mp
21 |
22 | ROOT_PATH = osp.abspath(osp.join(osp.dirname(osp.abspath(__file__)), ".."))
23 | sys.path.insert(0, ROOT_PATH)
24 | from lib.utils import mkdir_p,get_rank,merge_args_yaml,get_time_stamp,save_args
25 | from lib.utils import load_netG,load_npz,save_models
26 | from lib.perpare import prepare_dataloaders
27 | from lib.perpare import prepare_models
28 | from lib.modules import test as test
29 |
30 |
31 | def parse_args():
32 | # Training settings
33 | parser = argparse.ArgumentParser(description='Text2Img')
34 | parser.add_argument('--cfg', dest='cfg_file', type=str, default='../cfg/coco.yml',
35 | help='optional config file')
36 | parser.add_argument('--num_workers', type=int, default=4,
37 | help='number of workers(default: {0})'.format(mp.cpu_count() - 1))
38 | parser.add_argument('--stamp', type=str, default='normal',
39 | help='the stamp of model')
40 | parser.add_argument('--pretrained_model_path', type=str, default='model',
41 | help='the model for training')
42 | parser.add_argument('--log_dir', type=str, default='new',
43 | help='file path to log directory')
44 | parser.add_argument('--model', type=str, default='GALIP',
45 | help='the model for training')
46 | parser.add_argument('--state_epoch', type=int, default=100,
47 | help='state epoch')
48 | parser.add_argument('--batch_size', type=int, default=1024,
49 | help='batch size')
50 | parser.add_argument('--train', type=str, default='True',
51 | help='if train model')
52 | parser.add_argument('--mixed_precision', type=str, default='False',
53 | help='if use mixed precision')
54 | parser.add_argument('--multi_gpus', type=str, default='False',
55 | help='if use multi-gpu')
56 | parser.add_argument('--gpu_id', type=int, default=1,
57 | help='gpu id')
58 | parser.add_argument('--local_rank', default=-1, type=int,
59 | help='node rank for distributed training')
60 | parser.add_argument('--random_sample', action='store_true',default=True,
61 | help='whether to sample the dataset with random sampler')
62 | args = parser.parse_args()
63 | return args
64 |
65 |
66 | def main(args):
67 | time_stamp = get_time_stamp()
68 | stamp = '_'.join([str(args.model),str(args.stamp),str(args.CONFIG_NAME),str(args.imsize),time_stamp])
69 | log_dir = osp.join(ROOT_PATH, 'logs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp)))
70 | if (args.multi_gpus==True) and (get_rank() != 0):
71 | None
72 | else:
73 | mkdir_p(osp.join(ROOT_PATH, 'logs'))
74 | # prepare TensorBoard
75 | if (args.multi_gpus==True) and (get_rank() != 0):
76 | writer = None
77 | else:
78 | writer = SummaryWriter(log_dir)
79 | # Build and load the generator
80 | # prepare dataloader, models, data
81 | train_dl, valid_dl ,train_ds, valid_ds, sampler = prepare_dataloaders(args)
82 | CLIP4trn, CLIP4evl, image_encoder, text_encoder, netG, netD, netC = prepare_models(args)
83 | state_path = args.pretrained_model_path
84 | multi_gpus = args.multi_gpus
85 | m1, s1 = load_npz(args.npz_path)
86 | netG = load_netG(netG, state_path, multi_gpus, args.train)
87 |
88 | save_models(netG, netD, netC, 0, args.multi_gpus, './tmp')
89 |
90 | netG.eval()
91 | FID, TI_score = test(valid_dl, text_encoder, netG, CLIP4evl, args.device, m1, s1, -1, -1, \
92 | args.sample_times, args.z_dim, args.batch_size)
93 | if (args.multi_gpus==True) and (get_rank() != 0):
94 | None
95 | else:
96 | print('FID: %.2f, CLIP_Score: %.2f' % (FID, TI_score*100))
97 |
98 |
99 | if __name__ == "__main__":
100 | args = merge_args_yaml(parse_args())
101 | # set seed
102 | if args.manual_seed is None:
103 | args.manual_seed = 100
104 | #args.manualSeed = random.randint(1, 10000)
105 | random.seed(args.manual_seed)
106 | np.random.seed(args.manual_seed)
107 | torch.manual_seed(args.manual_seed)
108 | if args.cuda:
109 | if args.multi_gpus:
110 | torch.cuda.manual_seed_all(args.manual_seed)
111 | torch.distributed.init_process_group(backend="nccl")
112 | local_rank = torch.distributed.get_rank()
113 | torch.cuda.set_device(local_rank)
114 | args.device = torch.device("cuda", local_rank)
115 | args.local_rank = local_rank
116 | else:
117 | torch.cuda.manual_seed_all(args.manual_seed)
118 | torch.cuda.set_device(args.gpu_id)
119 | args.device = torch.device("cuda")
120 | else:
121 | args.device = torch.device('cpu')
122 | main(args)
123 |
124 |
125 |
--------------------------------------------------------------------------------
/code/lib/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import numpy as np
5 | import pandas as pd
6 | from PIL import Image
7 | import numpy.random as random
8 | if sys.version_info[0] == 2:
9 | import cPickle as pickle
10 | else:
11 | import pickle
12 | import torch
13 | import torch.utils.data as data
14 | from torch.autograd import Variable
15 | import torchvision.transforms as transforms
16 | import clip as clip
17 |
18 |
19 | def get_fix_data(train_dl, test_dl, text_encoder, args):
20 | fixed_image_train, _, _, fixed_sent_train, fixed_word_train, fixed_key_train = get_one_batch_data(train_dl, text_encoder, args)
21 | fixed_image_test, _, _, fixed_sent_test, fixed_word_test, fixed_key_test= get_one_batch_data(test_dl, text_encoder, args)
22 | fixed_image = torch.cat((fixed_image_train, fixed_image_test), dim=0)
23 | fixed_sent = torch.cat((fixed_sent_train, fixed_sent_test), dim=0)
24 | fixed_word = torch.cat((fixed_word_train, fixed_word_test), dim=0)
25 | fixed_noise = torch.randn(fixed_image.size(0), args.z_dim).to(args.device)
26 | return fixed_image, fixed_sent, fixed_word, fixed_noise
27 |
28 |
29 | def get_one_batch_data(dataloader, text_encoder, args):
30 | data = next(iter(dataloader))
31 | imgs, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, args.device)
32 | return imgs, captions, CLIP_tokens, sent_emb, words_embs, keys
33 |
34 |
35 | def prepare_data(data, text_encoder, device):
36 | imgs, captions, CLIP_tokens, keys = data
37 | imgs, CLIP_tokens = imgs.to(device), CLIP_tokens.to(device)
38 | sent_emb, words_embs = encode_tokens(text_encoder, CLIP_tokens)
39 | return imgs, captions, CLIP_tokens, sent_emb, words_embs, keys
40 |
41 |
42 | def encode_tokens(text_encoder, caption):
43 | # encode text
44 | with torch.no_grad():
45 | sent_emb,words_embs = text_encoder(caption)
46 | sent_emb,words_embs = sent_emb.detach(), words_embs.detach()
47 | return sent_emb, words_embs
48 |
49 |
50 | def get_imgs(img_path, bbox=None, transform=None, normalize=None):
51 | img = Image.open(img_path).convert('RGB')
52 | width, height = img.size
53 | if bbox is not None:
54 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
55 | center_x = int((2 * bbox[0] + bbox[2]) / 2)
56 | center_y = int((2 * bbox[1] + bbox[3]) / 2)
57 | y1 = np.maximum(0, center_y - r)
58 | y2 = np.minimum(height, center_y + r)
59 | x1 = np.maximum(0, center_x - r)
60 | x2 = np.minimum(width, center_x + r)
61 | img = img.crop([x1, y1, x2, y2])
62 | if transform is not None:
63 | img = transform(img)
64 | if normalize is not None:
65 | img = normalize(img)
66 | return img
67 |
68 |
69 | def get_caption(cap_path,clip_info):
70 | eff_captions = []
71 | with open(cap_path, "r") as f:
72 | captions = f.read().encode('utf-8').decode('utf8').split('\n')
73 | for cap in captions:
74 | if len(cap) != 0:
75 | eff_captions.append(cap)
76 | sent_ix = random.randint(0, len(eff_captions))
77 | caption = eff_captions[sent_ix]
78 | tokens = clip.tokenize(caption,truncate=True)
79 | return caption, tokens[0]
80 |
81 |
82 | ################################################################
83 | # Dataset
84 | ################################################################
85 | class TextImgDataset(data.Dataset):
86 | def __init__(self, split, transform=None, args=None):
87 | self.transform = transform
88 | self.clip4text = args.clip4text
89 | self.data_dir = args.data_dir
90 | self.dataset_name = args.dataset_name
91 | self.norm = transforms.Compose([
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
94 | ])
95 | self.split=split
96 |
97 | if self.data_dir.find('birds') != -1:
98 | self.bbox = self.load_bbox()
99 | else:
100 | self.bbox = None
101 | self.split_dir = os.path.join(self.data_dir, split)
102 | self.filenames = self.load_filenames(self.data_dir, split)
103 | self.number_example = len(self.filenames)
104 |
105 | def load_bbox(self):
106 | data_dir = self.data_dir
107 | bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
108 | df_bounding_boxes = pd.read_csv(bbox_path,
109 | delim_whitespace=True,
110 | header=None).astype(int)
111 | #
112 | filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
113 | df_filenames = \
114 | pd.read_csv(filepath, delim_whitespace=True, header=None)
115 | filenames = df_filenames[1].tolist()
116 | print('Total filenames: ', len(filenames), filenames[0])
117 | #
118 | filename_bbox = {img_file[:-4]: [] for img_file in filenames}
119 | numImgs = len(filenames)
120 | for i in range(0, numImgs):
121 | # bbox = [x-left, y-top, width, height]
122 | bbox = df_bounding_boxes.iloc[i][1:].tolist()
123 | key = filenames[i][:-4]
124 | filename_bbox[key] = bbox
125 | return filename_bbox
126 |
127 | def load_filenames(self, data_dir, split):
128 | filepath = '%s/%s/filenames.pickle' % (data_dir, split)
129 | if os.path.isfile(filepath):
130 | with open(filepath, 'rb') as f:
131 | filenames = pickle.load(f)
132 | print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
133 | else:
134 | filenames = []
135 | return filenames
136 |
137 | def __getitem__(self, index):
138 | #
139 | key = self.filenames[index]
140 | data_dir = self.data_dir
141 | #
142 | if self.bbox is not None:
143 | bbox = self.bbox[key]
144 | else:
145 | bbox = None
146 | #
147 | if self.dataset_name.lower().find('coco') != -1:
148 | if self.split=='train':
149 | img_name = '%s/images/train2014/jpg/%s.jpg' % (data_dir, key)
150 | text_name = '%s/text/%s.txt' % (data_dir, key)
151 | else:
152 | img_name = '%s/images/val2014/jpg/%s.jpg' % (data_dir, key)
153 | text_name = '%s/text/%s.txt' % (data_dir, key)
154 | elif self.dataset_name.lower().find('cc3m') != -1:
155 | if self.split=='train':
156 | img_name = '%s/images/train/%s.jpg' % (data_dir, key)
157 | text_name = '%s/text/train/%s.txt' % (data_dir, key.split('_')[0])
158 | else:
159 | img_name = '%s/images/test/%s.jpg' % (data_dir, key)
160 | text_name = '%s/text/test/%s.txt' % (data_dir, key.split('_')[0])
161 | elif self.dataset_name.lower().find('cc12m') != -1:
162 | if self.split=='train':
163 | img_name = '%s/images/%s.jpg' % (data_dir, key)
164 | text_name = '%s/text/%s.txt' % (data_dir, key.split('_')[0])
165 | else:
166 | img_name = '%s/images/%s.jpg' % (data_dir, key)
167 | text_name = '%s/text/%s.txt' % (data_dir, key.split('_')[0])
168 | else:
169 | img_name = '%s/CUB_200_2011/images/%s.jpg' % (data_dir, key)
170 | text_name = '%s/text/%s.txt' % (data_dir, key)
171 | #
172 | imgs = get_imgs(img_name, bbox, self.transform, normalize=self.norm)
173 | caps,tokens = get_caption(text_name,self.clip4text)
174 | return imgs, caps, tokens, key
175 |
176 | def __len__(self):
177 | return len(self.filenames)
178 |
179 |
--------------------------------------------------------------------------------
/code/lib/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import errno
4 | import numpy as np
5 | import numpy.random as random
6 | import torch
7 | from torch import distributed as dist
8 | from tqdm import tqdm
9 | import yaml
10 | from easydict import EasyDict as edict
11 | import pprint
12 | import datetime
13 | import dateutil.tz
14 | from PIL import Image
15 |
16 | import importlib
18 | import torch.nn.functional as F
19 | try:
20 | from torchvision.transforms import InterpolationMode
21 | BICUBIC = InterpolationMode.BICUBIC
22 | except ImportError:
23 | BICUBIC = Image.BICUBIC
24 |
25 |
26 | def choose_model(model):
27 | '''choose models
28 | '''
29 | model = importlib.import_module(".%s"%(model), "models")
30 | NetG, NetD, NetC, CLIP_IMG_ENCODER, CLIP_TXT_ENCODER = model.NetG, model.NetD, model.NetC, model.CLIP_IMG_ENCODER, model.CLIP_TXT_ENCODER
31 | return NetG,NetD,NetC,CLIP_IMG_ENCODER, CLIP_TXT_ENCODER
32 |
33 |
34 | def params_count(model):
35 | model_size = np.sum([p.numel() for p in model.parameters()]).item()
36 | return model_size
37 |
38 |
39 | def get_time_stamp():
40 | now = datetime.datetime.now(dateutil.tz.tzlocal())
41 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
42 | return timestamp
43 |
44 |
45 | def mkdir_p(path):
46 | try:
47 | os.makedirs(path)
48 | except OSError as exc: # Python >2.5
49 | if exc.errno == errno.EEXIST and os.path.isdir(path):
50 | pass
51 | else:
52 | raise
53 |
54 |
55 | def load_npz(path):
56 | f = np.load(path)
57 | m, s = f['mu'][:], f['sigma'][:]
58 | f.close()
59 | return m, s
60 |
61 | # config
62 | def load_yaml(filename):
63 | with open(filename, 'r') as f:
64 | cfg = edict(yaml.load(f, Loader=yaml.FullLoader))
65 | return cfg
66 |
67 |
68 | def str2bool_dict(dict):
69 | for key,value in dict.items():
70 | if type(value)==str:
71 | if value.lower() in ('yes','true'):
72 | dict[key] = True
73 | elif value.lower() in ('no','false'):
74 | dict[key] = False
75 | else:
76 | None
77 | return dict
78 |
79 |
80 | def merge_args_yaml(args):
81 | if args.cfg_file is not None:
82 | opt = vars(args)
83 | args = load_yaml(args.cfg_file)
84 | args.update(opt)
85 | args = str2bool_dict(args)
86 | args = edict(args)
87 | return args
88 |
89 |
90 | def save_args(save_path, args):
91 | fp = open(save_path, 'w')
92 | fp.write(yaml.dump(args))
93 | fp.close()
94 |
95 |
96 | def read_txt_file(txt_file):
97 | content = []
98 | with open(txt_file, "r") as f:
99 | for line in f.readlines():
100 | line = line.strip('\n')
101 | content.append(line)
102 | return content
103 |
104 |
105 | def get_rank():
106 | if not dist.is_available():
107 | return 0
108 | if not dist.is_initialized():
109 | return 0
110 | return dist.get_rank()
111 |
112 |
113 | def load_opt_weights(optimizer, weights):
114 | optimizer.load_state_dict(weights)
115 | return optimizer
116 |
117 |
118 | def load_models_opt(netG, netD, netC, optim_G, optim_D, path, multi_gpus):
119 | checkpoint = torch.load(path, map_location=torch.device('cpu'))
120 | netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus)
121 | netD = load_model_weights(netD, checkpoint['model']['netD'], multi_gpus)
122 | netC = load_model_weights(netC, checkpoint['model']['netC'], multi_gpus)
123 | optim_G = load_opt_weights(optim_G, checkpoint['optimizers']['optimizer_G'])
124 | optim_D = load_opt_weights(optim_D, checkpoint['optimizers']['optimizer_D'])
125 | return netG, netD, netC, optim_G, optim_D
126 |
127 |
128 | def load_models(netG, netD, netC, path):
129 | checkpoint = torch.load(path, map_location=torch.device('cpu'))
130 | netG = load_model_weights(netG, checkpoint['model']['netG'])
131 | netD = load_model_weights(netD, checkpoint['model']['netD'])
132 | netC = load_model_weights(netC, checkpoint['model']['netC'])
133 | return netG, netD, netC
134 |
135 |
136 | def load_netG(netG, path, multi_gpus, train):
137 | checkpoint = torch.load(path, map_location="cpu")
138 | netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus, train)
139 | return netG
140 |
141 |
142 | def load_model_weights(model, weights, multi_gpus, train=True):
143 | if list(weights.keys())[0].find('module')==-1:
144 | pretrained_with_multi_gpu = False
145 | else:
146 | pretrained_with_multi_gpu = True
147 | if (multi_gpus==False) or (train==False):
148 | if pretrained_with_multi_gpu:
149 | state_dict = {
150 | key[7:]: value
151 | for key, value in weights.items()
152 | }
153 | else:
154 | state_dict = weights
155 | else:
156 | state_dict = weights
157 | model.load_state_dict(state_dict)
158 | return model
159 |
160 |
161 | def save_models_opt(netG, netD, netC, optG, optD, epoch, multi_gpus, save_path):
162 | if (multi_gpus==True) and (get_rank() != 0):
163 | None
164 | else:
165 | state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, \
166 | 'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},\
167 | 'epoch': epoch}
168 | torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch))
169 |
170 |
171 | def save_models(netG, netD, netC, epoch, multi_gpus, save_path):
172 | if (multi_gpus==True) and (get_rank() != 0):
173 | None
174 | else:
175 | state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}}
176 | torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch))
177 |
178 |
179 | def save_checkpoints(netG, netD, netC, optG, optD, scaler_G, scaler_D, epoch, multi_gpus, save_path):
180 | if (multi_gpus==True) and (get_rank() != 0):
181 | None
182 | else:
183 | state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, \
184 | 'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},\
185 | "scalers": {"scaler_G": scaler_G.state_dict(), "scaler_D": scaler_D.state_dict()},\
186 | 'epoch': epoch}
187 | torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch))
188 |
189 |
190 | def write_to_txt(filename, contents):
191 | fh = open(filename, 'w')
192 | fh.write(contents)
193 | fh.close()
194 |
195 |
196 | def read_txt_file(txt_file):
197 | # text_file: file path
198 | content = []
199 | with open(txt_file, "r") as f:
200 | for line in f.readlines():
201 | line = line.strip('\n')
202 | content.append(line)
203 | return content
204 |
205 |
206 | def save_img(img, path):
207 | im = img.data.cpu().numpy()
208 | # [-1, 1] --> [0, 255]
209 | im = (im + 1.0) * 127.5
210 | im = im.astype(np.uint8)
211 | im = np.transpose(im, (1, 2, 0))
212 | im = Image.fromarray(im)
213 | im.save(path)
214 |
215 |
216 | def transf_to_CLIP_input(inputs):
217 | device = inputs.device
218 | if len(inputs.size()) != 4:
219 | raise ValueError('Expect the (B, C, X, Y) tensor.')
220 | else:
221 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])\
222 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
223 | var = torch.tensor([0.26862954, 0.26130258, 0.27577711])\
224 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
225 | inputs = F.interpolate(inputs*0.5+0.5, size=(224, 224))
226 | inputs = ((inputs+1)*0.5-mean)/var
227 | return inputs.float()
228 |
229 |
230 | class dummy_context_mgr():
231 | def __enter__(self):
232 | return None
233 |
234 | def __exit__(self, exc_type, exc_value, traceback):
235 | return False
236 |
--------------------------------------------------------------------------------
/code/src/train.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import os.path as osp
3 | import time
4 | import random
5 | import argparse
6 | import numpy as np
7 | from PIL import Image
8 | import pprint
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | from torch.autograd import Variable
14 | import torch.backends.cudnn as cudnn
15 | from torchvision.utils import save_image,make_grid
16 | from torch.utils.tensorboard import SummaryWriter
17 | import torchvision.transforms as transforms
18 | import torchvision.utils as vutils
19 | from torch.utils.data.distributed import DistributedSampler
20 | import multiprocessing as mp
21 |
22 | ROOT_PATH = osp.abspath(osp.join(osp.dirname(osp.abspath(__file__)), ".."))
23 | sys.path.insert(0, ROOT_PATH)
24 | from lib.utils import mkdir_p,get_rank,merge_args_yaml,get_time_stamp,save_args
25 | from lib.utils import load_models_opt,save_models_opt,save_models,load_npz,params_count
26 | from lib.perpare import prepare_dataloaders,prepare_models
27 | from lib.modules import sample_one_batch as sample, test as test, train as train
28 | from lib.datasets import get_fix_data
29 |
30 |
31 | def parse_args():
32 | # Training settings
33 | parser = argparse.ArgumentParser(description='Text2Img')
34 | parser.add_argument('--cfg', dest='cfg_file', type=str, default='../cfg/coco.yml',
35 | help='optional config file')
36 | parser.add_argument('--num_workers', type=int, default=4,
37 | help='number of workers(default: {0})'.format(mp.cpu_count() - 1))
38 | parser.add_argument('--stamp', type=str, default='normal',
39 | help='the stamp of model')
40 | parser.add_argument('--pretrained_model_path', type=str, default='model',
41 | help='the model for training')
42 | parser.add_argument('--log_dir', type=str, default='new',
43 | help='file path to log directory')
44 | parser.add_argument('--model', type=str, default='GALIP',
45 | help='the model for training')
46 | parser.add_argument('--state_epoch', type=int, default=100,
47 | help='state epoch')
48 | parser.add_argument('--batch_size', type=int, default=1024,
49 | help='batch size')
50 | parser.add_argument('--train', type=str, default='True',
51 | help='if train model')
52 | parser.add_argument('--mixed_precision', type=str, default='False',
53 | help='if use mixed precision')
54 | parser.add_argument('--multi_gpus', type=str, default='False',
55 | help='if use multi-gpu')
56 | parser.add_argument('--gpu_id', type=int, default=1,
57 | help='gpu id')
58 | parser.add_argument('--local_rank', default=-1, type=int,
59 | help='node rank for distributed training')
60 | parser.add_argument('--random_sample', action='store_true',default=True,
61 | help='whether to sample the dataset with random sampler')
62 | args = parser.parse_args()
63 | return args
64 |
65 |
66 | def main(args):
67 | time_stamp = get_time_stamp()
68 | stamp = '_'.join([str(args.model),'nf'+str(args.nf),str(args.stamp),str(args.CONFIG_NAME),str(args.imsize),time_stamp])
69 | args.model_save_file = osp.join(ROOT_PATH, 'saved_models', str(args.CONFIG_NAME), stamp)
70 | log_dir = args.log_dir
71 | if log_dir == 'new':
72 | log_dir = osp.join(ROOT_PATH, 'logs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp)))
73 | args.img_save_dir = osp.join(ROOT_PATH, 'imgs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp)))
74 | if (args.multi_gpus==True) and (get_rank() != 0):
75 | None
76 | else:
77 | mkdir_p(osp.join(ROOT_PATH, 'logs'))
78 | mkdir_p(args.model_save_file)
79 | mkdir_p(args.img_save_dir)
80 | # prepare TensorBoard
81 | if (args.multi_gpus==True) and (get_rank() != 0):
82 | writer = None
83 | else:
84 | writer = SummaryWriter(log_dir)
85 | # prepare dataloader, models, data
86 | train_dl, valid_dl ,train_ds, valid_ds, sampler = prepare_dataloaders(args)
87 | CLIP4trn, CLIP4evl, image_encoder, text_encoder, netG, netD, netC = prepare_models(args)
88 | print('**************G_paras: ',params_count(netG))
89 | print('**************D_paras: ',params_count(netD)+params_count(netC))
90 | fixed_img, fixed_sent, fixed_words, fixed_z = get_fix_data(train_dl, valid_dl, text_encoder, args)
91 | if (args.multi_gpus==True) and (get_rank() != 0):
92 | None
93 | else:
94 | fixed_grid = make_grid(fixed_img.cpu(), nrow=8, normalize=True)
95 | #writer.add_image('fixed images', fixed_grid, 0)
96 | img_name = 'gt.png'
97 | img_save_path = osp.join(args.img_save_dir, img_name)
98 | vutils.save_image(fixed_img.data, img_save_path, nrow=8, normalize=True)
99 | # prepare optimizer
100 | D_params = list(netD.parameters()) + list(netC.parameters())
101 | optimizerD = torch.optim.Adam(D_params, lr=args.lr_d, betas=(0.0, 0.9))
102 | optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr_g, betas=(0.0, 0.9))
103 | if args.mixed_precision==True:
104 | scaler_D = torch.cuda.amp.GradScaler(growth_interval=args.growth_interval)
105 | scaler_G = torch.cuda.amp.GradScaler(growth_interval=args.growth_interval)
106 | else:
107 | scaler_D = None
108 | scaler_G = None
109 | m1, s1 = load_npz(args.npz_path)
110 | start_epoch = 1
111 | # load from checkpoint
112 | if args.state_epoch!=1:
113 | start_epoch = args.state_epoch + 1
114 | path = osp.join(args.pretrained_model_path, 'state_epoch_%03d.pth'%(args.state_epoch))
115 | netG, netD, netC, optimizerG, optimizerD = load_models_opt(netG, netD, netC, optimizerG, optimizerD, path, args.multi_gpus)
116 | #netG, netD, netC = load_model(netG, netD, netC, path, args.multi_gpus)
117 | # print args
118 | if (args.multi_gpus==True) and (get_rank() != 0):
119 | None
120 | else:
121 | pprint.pprint(args)
122 | arg_save_path = osp.join(log_dir, 'args.yaml')
123 | save_args(arg_save_path, args)
124 | print("Start Training")
125 | # Start training
126 | test_interval,gen_interval,save_interval = args.test_interval,args.gen_interval,args.save_interval
127 | #torch.cuda.empty_cache()
128 | # start_epoch = 1
129 | for epoch in range(start_epoch, args.max_epoch, 1):
130 | if (args.multi_gpus==True):
131 | sampler.set_epoch(epoch)
132 | start_t = time.time()
133 | # training
134 | args.current_epoch = epoch
135 | torch.cuda.empty_cache()
136 | train(train_dl, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args)
137 | torch.cuda.empty_cache()
138 | # save
139 | if epoch%save_interval==0:
140 | save_models_opt(netG, netD, netC, optimizerG, optimizerD, epoch, args.multi_gpus, args.model_save_file)
141 | torch.cuda.empty_cache()
142 | # sample
143 | if epoch%gen_interval==0:
144 | sample(fixed_z, fixed_sent, netG, args.multi_gpus, epoch, args.img_save_dir, writer)
145 | torch.cuda.empty_cache()
146 | # test
147 | if epoch%test_interval==0:
148 | #torch.cuda.empty_cache()
149 | FID, TI_score = test(valid_dl, text_encoder, netG, CLIP4evl, args.device, m1, s1, epoch, args.max_epoch, args.sample_times, args.z_dim, args.batch_size)
150 | torch.cuda.empty_cache()
151 | if (args.multi_gpus==True) and (get_rank() != 0):
152 | None
153 | else:
154 | if epoch%test_interval==0:
155 | writer.add_scalar('FID', FID, epoch)
156 | writer.add_scalar('CLIP_Score', TI_score, epoch)
157 | print('The %d epoch FID: %.2f, CLIP_Score: %.2f' % (epoch,FID,TI_score*100))
158 | end_t = time.time()
159 | print('The epoch %d costs %.2fs'%(epoch, end_t-start_t))
160 | print('*'*40)
161 |
162 |
163 |
164 | if __name__ == "__main__":
165 | args = merge_args_yaml(parse_args())
166 | # set seed
167 | if args.manual_seed is None:
168 | args.manual_seed = 100
169 | #args.manualSeed = random.randint(1, 10000)
170 | random.seed(args.manual_seed)
171 | np.random.seed(args.manual_seed)
172 | torch.manual_seed(args.manual_seed)
173 | if args.cuda:
174 | if args.multi_gpus:
175 | torch.cuda.manual_seed_all(args.manual_seed)
176 | torch.distributed.init_process_group(backend="nccl")
177 | local_rank = torch.distributed.get_rank()
178 | torch.cuda.set_device(local_rank)
179 | args.device = torch.device("cuda", local_rank)
180 | args.local_rank = local_rank
181 | else:
182 | torch.cuda.manual_seed_all(args.manual_seed)
183 | torch.cuda.set_device(args.gpu_id)
184 | args.device = torch.device("cuda")
185 | else:
186 | args.device = torch.device('cpu')
187 | main(args)
188 |
189 |
--------------------------------------------------------------------------------
/code/models/GALIP.py:
--------------------------------------------------------------------------------
1 | ####### Generator and discriminator without residual connections
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from collections import OrderedDict
7 | from lib.utils import dummy_context_mgr
8 |
9 |
10 | class CLIP_IMG_ENCODER(nn.Module):
11 | def __init__(self, CLIP):
12 | super(CLIP_IMG_ENCODER, self).__init__()
13 | model = CLIP.visual
14 | # print(model)
15 | self.define_module(model)
16 | for param in self.parameters():
17 | param.requires_grad = False
18 |
19 | def define_module(self, model):
20 | self.conv1 = model.conv1
21 | self.class_embedding = model.class_embedding
22 | self.positional_embedding = model.positional_embedding
23 | self.ln_pre = model.ln_pre
24 | self.transformer = model.transformer
25 | self.ln_post = model.ln_post
26 | self.proj = model.proj
27 |
28 | @property
29 | def dtype(self):
30 | return self.conv1.weight.dtype
31 |
32 | def transf_to_CLIP_input(self,inputs):
33 | device = inputs.device
34 | if len(inputs.size()) != 4:
35 | raise ValueError('Expect the (B, C, X, Y) tensor.')
36 | else:
37 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])\
38 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
39 | var = torch.tensor([0.26862954, 0.26130258, 0.27577711])\
40 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
41 | inputs = F.interpolate(inputs*0.5+0.5, size=(224, 224))
42 | inputs = ((inputs+1)*0.5-mean)/var
43 | return inputs
44 |
45 | def forward(self, img: torch.Tensor):
46 | x = self.transf_to_CLIP_input(img)
47 | x = x.type(self.dtype)
48 | x = self.conv1(x) # shape = [*, width, grid, grid]
49 | grid = x.size(-1)
50 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
51 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
52 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
53 | x = x + self.positional_embedding.to(x.dtype)
54 | x = self.ln_pre(x)
55 | # NLD -> LND
56 | x = x.permute(1, 0, 2)
57 | # Local features
58 | #selected = [1,4,7,12]
59 | selected = [1,4,8]
60 | local_features = []
61 | for i in range(12):
62 | x = self.transformer.resblocks[i](x)
63 | if i in selected:
64 | local_features.append(x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype))
65 | x = x.permute(1, 0, 2) # LND -> NLD
66 | x = self.ln_post(x[:, 0, :])
67 | if self.proj is not None:
68 | x = x @ self.proj
69 | return torch.stack(local_features, dim=1), x.type(img.dtype)
70 |
71 |
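    | # Frozen CLIP text encoder: returns the projected sentence embedding taken at the EOT
    | # token together with the full per-token feature sequence.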
72 | class CLIP_TXT_ENCODER(nn.Module):
73 | def __init__(self, CLIP):
74 | super(CLIP_TXT_ENCODER, self).__init__()
75 | self.define_module(CLIP)
76 | # print(model)
77 | for param in self.parameters():
78 | param.requires_grad = False
79 |
80 | def define_module(self, CLIP):
81 | self.transformer = CLIP.transformer
82 | self.vocab_size = CLIP.vocab_size
83 | self.token_embedding = CLIP.token_embedding
84 | self.positional_embedding = CLIP.positional_embedding
85 | self.ln_final = CLIP.ln_final
86 | self.text_projection = CLIP.text_projection
87 |
88 | @property
89 | def dtype(self):
90 | return self.transformer.resblocks[0].mlp.c_fc.weight.dtype
91 |
92 | def forward(self, text):
93 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
94 | x = x + self.positional_embedding.type(self.dtype)
95 | x = x.permute(1, 0, 2) # NLD -> LND
96 | x = self.transformer(x)
97 | x = x.permute(1, 0, 2) # LND -> NLD
98 | x = self.ln_final(x).type(self.dtype)
99 | # x.shape = [batch_size, n_ctx, transformer.width]
100 | # take features from the eot embedding (eot_token is the highest number in each sequence)
101 | sent_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
102 | return sent_emb, x
103 |
104 |
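    | # Frozen CLIP ViT used as a feature mapper: for blocks 1-8 a learned prompt token is
    | # appended before the block and removed after it; the patch tokens are returned as a
    | # 768-channel spatial feature map.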
105 | class CLIP_Mapper(nn.Module):
106 | def __init__(self, CLIP):
107 | super(CLIP_Mapper, self).__init__()
108 | model = CLIP.visual
109 | # print(model)
110 | self.define_module(model)
111 | for param in model.parameters():
112 | param.requires_grad = False
113 |
114 | def define_module(self, model):
115 | self.conv1 = model.conv1
116 | self.class_embedding = model.class_embedding
117 | self.positional_embedding = model.positional_embedding
118 | self.ln_pre = model.ln_pre
119 | self.transformer = model.transformer
120 |
121 | @property
122 | def dtype(self):
123 | return self.conv1.weight.dtype
124 |
125 | def forward(self, img: torch.Tensor, prompts: torch.Tensor):
126 | x = img.type(self.dtype)
127 | prompts = prompts.type(self.dtype)
128 | grid = x.size(-1)
129 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
130 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
131 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
132 | # shape = [*, grid ** 2 + 1, width]
133 | x = x + self.positional_embedding.to(x.dtype)
134 | x = self.ln_pre(x)
135 | # NLD -> LND
136 | x = x.permute(1, 0, 2)
137 | # Local features
138 | selected = [1,2,3,4,5,6,7,8]
139 | begin, end = 0, 12
140 | prompt_idx = 0
141 | for i in range(begin, end):
142 | if i in selected:
143 | prompt = prompts[:,prompt_idx,:].unsqueeze(0)
144 | prompt_idx = prompt_idx+1
145 | x = torch.cat((x,prompt), dim=0)
146 | x = self.transformer.resblocks[i](x)
147 | x = x[:-1,:,:]
148 | else:
149 | x = self.transformer.resblocks[i](x)
150 | return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)
151 |
152 |
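    | # Bridge from the generator's latent feature map into the frozen CLIP ViT: text-conditioned
    | # M_Blocks refine the map, conv_fuse lifts it to 768 channels, CLIP_Mapper processes it with
    | # prompts derived from the condition vector, and a final conv maps back to generator channels.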
153 | class CLIP_Adapter(nn.Module):
154 | def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP):
155 | super(CLIP_Adapter, self).__init__()
156 | self.CLIP_ch = CLIP_ch
157 | self.FBlocks = nn.ModuleList([])
158 | self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p))
159 | for i in range(map_num-1):
160 | self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p))
161 | self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2)
162 | self.CLIP_ViT = CLIP_Mapper(CLIP)
163 | self.conv = nn.Conv2d(768, G_ch, 5, 1, 2)
164 | #
165 | self.fc_prompt = nn.Linear(cond_dim, CLIP_ch*8)
166 |
167 | def forward(self,out,c):
168 | prompts = self.fc_prompt(c).view(c.size(0),-1,self.CLIP_ch)
169 | for FBlock in self.FBlocks:
170 | out = FBlock(out,c)
171 | fuse_feat = self.conv_fuse(out)
172 | map_feat = self.CLIP_ViT(fuse_feat,prompts)
173 | return self.conv(fuse_feat+0.1*map_feat)
174 |
175 |
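    | # Generator: the noise code is reshaped to a 7x7 feature map, mapped through the
    | # CLIP_Adapter, then upsampled by text-conditioned G_Blocks up to 224x224 and converted
    | # to RGB by to_rgb.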
176 | class NetG(nn.Module):
177 | def __init__(self, ngf, nz, cond_dim, imsize, ch_size, mixed_precision, CLIP):
178 | super(NetG, self).__init__()
179 | self.ngf = ngf
180 | self.mixed_precision = mixed_precision
181 | # build CLIP Mapper
182 | self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32
183 | self.CLIP_ch = 768
184 | self.fc_code = nn.Linear(nz, self.code_sz*self.code_sz*self.code_ch)
185 | self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf*8, self.CLIP_ch, cond_dim+nz, 3, 1, 1, 4, CLIP)
186 | # build GBlocks
187 | self.GBlocks = nn.ModuleList([])
188 | in_out_pairs = list(get_G_in_out_chs(ngf, imsize))
189 | imsize = 4
190 | for idx, (in_ch, out_ch) in enumerate(in_out_pairs):
191 | if idx<(len(in_out_pairs)-1):
192 | imsize = imsize*2
193 | else:
194 | imsize = 224
195 | self.GBlocks.append(G_Block(cond_dim+nz, in_ch, out_ch, imsize))
196 | # to RGB image
197 | self.to_rgb = nn.Sequential(
198 | nn.LeakyReLU(0.2,inplace=True),
199 | nn.Conv2d(out_ch, ch_size, 3, 1, 1),
200 | #nn.Tanh(),
201 | )
202 |
203 | def forward(self, noise, c, eval=False): # x=noise, c=ent_emb
204 | with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp:
205 | cond = torch.cat((noise, c), dim=1)
206 | out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond)
207 | # fuse text and visual features
208 | for GBlock in self.GBlocks:
209 | out = GBlock(out, cond)
210 | # convert to RGB image
211 | out = self.to_rgb(out)
212 | return out
213 |
214 |
215 | # Define the discriminator network D
216 | class NetD(nn.Module):
217 | def __init__(self, ndf, imsize, ch_size, mixed_precision):
218 | super(NetD, self).__init__()
219 | self.mixed_precision = mixed_precision
220 | self.DBlocks = nn.ModuleList([
221 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
222 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
223 | ])
224 | self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False)
225 |
226 | def forward(self, h):
227 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
228 | out = h[:,0]
229 | for idx in range(len(self.DBlocks)):
230 | out = self.DBlocks[idx](out, h[:,idx+1])
231 | out = self.main(out)
232 | return out
233 |
234 |
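    | # Conditional discriminator head: tiles the sentence embedding over NetD's 7x7 output and
    | # predicts a single matching logit with two convolutions.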
235 | class NetC(nn.Module):
236 | def __init__(self, ndf, cond_dim, mixed_precision):
237 | super(NetC, self).__init__()
238 | self.cond_dim = cond_dim
239 | self.mixed_precision = mixed_precision
240 | self.joint_conv = nn.Sequential(
241 | nn.Conv2d(512+512, 128, 4, 1, 0, bias=False),
242 | nn.LeakyReLU(0.2, inplace=True),
243 | nn.Conv2d(128, 1, 4, 1, 0, bias=False),
244 | )
245 |
246 | def forward(self, out, cond):
247 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
248 | cond = cond.view(-1, self.cond_dim, 1, 1)
249 | cond = cond.repeat(1, 1, 7, 7)
250 | h_c_code = torch.cat((out, cond), 1)
251 | out = self.joint_conv(h_c_code)
252 | return out
253 |
254 |
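    | # Fusion block used inside CLIP_Adapter: two conv + text-conditioned DFBLK stages with a
    | # residual shortcut (1x1 conv when the channel count changes).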
255 | class M_Block(nn.Module):
256 | def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p):
257 | super(M_Block, self).__init__()
258 | self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p)
259 | self.fuse1 = DFBLK(cond_dim, mid_ch)
260 | self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p)
261 | self.fuse2 = DFBLK(cond_dim, out_ch)
262 | self.learnable_sc = in_ch != out_ch
263 | if self.learnable_sc:
264 | self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
265 |
266 | def shortcut(self, x):
267 | if self.learnable_sc:
268 | x = self.c_sc(x)
269 | return x
270 |
271 | def residual(self, h, text):
272 | h = self.conv1(h)
273 | h = self.fuse1(h, text)
274 | h = self.conv2(h)
275 | h = self.fuse2(h, text)
276 | return h
277 |
278 | def forward(self, h, c):
279 | return self.shortcut(h) + self.residual(h, c)
280 |
281 |
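    | # Generator block: upsamples to the target resolution, then applies two DFBLK + conv
    | # stages with a residual shortcut (1x1 conv when the channel count changes).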
282 | class G_Block(nn.Module):
283 | def __init__(self, cond_dim, in_ch, out_ch, imsize):
284 | super(G_Block, self).__init__()
285 | self.imsize = imsize
286 | self.learnable_sc = in_ch != out_ch
287 | self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
288 | self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
289 | self.fuse1 = DFBLK(cond_dim, in_ch)
290 | self.fuse2 = DFBLK(cond_dim, out_ch)
291 | if self.learnable_sc:
292 | self.c_sc = nn.Conv2d(in_ch,out_ch, 1, stride=1, padding=0)
293 |
294 | def shortcut(self, x):
295 | if self.learnable_sc:
296 | x = self.c_sc(x)
297 | return x
298 |
299 | def residual(self, h, y):
300 | h = self.fuse1(h, y)
301 | h = self.c1(h)
302 | h = self.fuse2(h, y)
303 | h = self.c2(h)
304 | return h
305 |
306 | def forward(self, h, y):
307 | h = F.interpolate(h, size=(self.imsize, self.imsize))
308 | return self.shortcut(h) + self.residual(h, y)
309 |
310 |
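    | # Discriminator block: a two-conv residual branch scaled by a learned gamma, optionally
    | # fused with the matching CLIP feature map through a second learned scale beta.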
311 | class D_Block(nn.Module):
312 | def __init__(self, fin, fout, k, s, p, res, CLIP_feat):
313 | super(D_Block, self).__init__()
314 | self.res, self.CLIP_feat = res, CLIP_feat
315 | self.learned_shortcut = (fin != fout)
316 | self.conv_r = nn.Sequential(
317 | nn.Conv2d(fin, fout, k, s, p, bias=False),
318 | nn.LeakyReLU(0.2, inplace=True),
319 | nn.Conv2d(fout, fout, k, s, p, bias=False),
320 | nn.LeakyReLU(0.2, inplace=True),
321 | )
322 | self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)
323 | if self.res==True:
324 | self.gamma = nn.Parameter(torch.zeros(1))
325 | if self.CLIP_feat==True:
326 | self.beta = nn.Parameter(torch.zeros(1))
327 |
328 | def forward(self, x, CLIP_feat=None):
329 | res = self.conv_r(x)
330 | if self.learned_shortcut:
331 | x = self.conv_s(x)
332 | if (self.res==True)and(self.CLIP_feat==True):
333 | return x + self.gamma*res + self.beta*CLIP_feat
334 | elif (self.res==True)and(self.CLIP_feat!=True):
335 | return x + self.gamma*res
336 | elif (self.res!=True)and(self.CLIP_feat==True):
337 | return x + self.beta*CLIP_feat
338 | else:
339 | return x
340 |
341 |
342 | class DFBLK(nn.Module):
343 | def __init__(self, cond_dim, in_ch):
344 | super(DFBLK, self).__init__()
345 | self.affine0 = Affine(cond_dim, in_ch)
346 | self.affine1 = Affine(cond_dim, in_ch)
347 |
348 | def forward(self, x, y=None):
349 | h = self.affine0(x, y)
350 | h = nn.LeakyReLU(0.2,inplace=True)(h)
351 | h = self.affine1(h, y)
352 | h = nn.LeakyReLU(0.2,inplace=True)(h)
353 | return h
354 |
355 |
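    | # GELU approximation used in CLIP: x * sigmoid(1.702 * x).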
356 | class QuickGELU(nn.Module):
357 | def forward(self, x: torch.Tensor):
358 | return x * torch.sigmoid(1.702 * x)
359 |
360 |
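    | # Conditional affine modulation: two small MLPs predict per-channel scale and shift from
    | # the condition vector; initialized to the identity (scale 1, shift 0).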
361 | class Affine(nn.Module):
362 | def __init__(self, cond_dim, num_features):
363 | super(Affine, self).__init__()
364 |
365 | self.fc_gamma = nn.Sequential(OrderedDict([
366 | ('linear1',nn.Linear(cond_dim, num_features)),
367 | ('relu1',nn.ReLU(inplace=True)),
368 | ('linear2',nn.Linear(num_features, num_features)),
369 | ]))
370 | self.fc_beta = nn.Sequential(OrderedDict([
371 | ('linear1',nn.Linear(cond_dim, num_features)),
372 | ('relu1',nn.ReLU(inplace=True)),
373 | ('linear2',nn.Linear(num_features, num_features)),
374 | ]))
375 | self._initialize()
376 |
377 | def _initialize(self):
378 | nn.init.zeros_(self.fc_gamma.linear2.weight.data)
379 | nn.init.ones_(self.fc_gamma.linear2.bias.data)
380 | nn.init.zeros_(self.fc_beta.linear2.weight.data)
381 | nn.init.zeros_(self.fc_beta.linear2.bias.data)
382 |
383 | def forward(self, x, y=None):
384 | weight = self.fc_gamma(y)
385 | bias = self.fc_beta(y)
386 |
387 | if weight.dim() == 1:
388 | weight = weight.unsqueeze(0)
389 | if bias.dim() == 1:
390 | bias = bias.unsqueeze(0)
391 |
392 | size = x.size()
393 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
394 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
395 | return weight * x + bias
396 |
397 |
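    | # Channel schedule for the generator: log2(imsize)-1 stages with widths nf*min(2^i, 8),
    | # reversed so channels shrink as resolution grows; yields consecutive (in, out) pairs.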
398 | def get_G_in_out_chs(nf, imsize):
399 | layer_num = int(np.log2(imsize))-1
400 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)]
401 | channel_nums = channel_nums[::-1]
402 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
403 | return in_out_pairs
404 |
405 |
406 | def get_D_in_out_chs(nf, imsize):
407 | layer_num = int(np.log2(imsize))-1
408 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)]
409 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
410 | return in_out_pairs
411 |
--------------------------------------------------------------------------------
/code/lib/modules.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | import os.path as osp
4 | import time
5 | import random
6 | import datetime
7 | import argparse
8 | from scipy import linalg
9 | import numpy as np
10 | from PIL import Image
11 | from tqdm import tqdm, trange
12 | from torch.autograd import Variable
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.optim as optim
17 | import torch.backends.cudnn as cudnn
18 | import torchvision.transforms as transforms
19 | import torchvision.utils as vutils
20 | from torchvision.utils import make_grid
21 | from lib.utils import transf_to_CLIP_input, dummy_context_mgr
22 | from lib.utils import mkdir_p, get_rank
23 | from lib.datasets import prepare_data
24 |
25 | from models.inception import InceptionV3
26 | from torch.nn.functional import adaptive_avg_pool2d
27 | import torch.distributed as dist
28 |
29 |
30 | ############ GAN ############
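    | # One training epoch: the discriminator is updated with hinge losses on real, mismatched-
    | # caption, and generated pairs plus the MA-GP gradient penalty; GradScalers handle the
    | # optional mixed-precision updates for both networks.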
31 | def train(dataloader, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args):
32 | batch_size = args.batch_size
33 | device = args.device
34 | epoch = args.current_epoch
35 | max_epoch = args.max_epoch
36 | z_dim = args.z_dim
37 | netG, netD, netC, image_encoder = netG.train(), netD.train(), netC.train(), image_encoder.train()
38 | if (args.multi_gpus==True) and (get_rank() != 0):
39 | None
40 | else:
41 | loop = tqdm(total=len(dataloader))
42 | for step, data in enumerate(dataloader, 0):
43 | ##############
44 | # Train D
45 | ##############
46 | optimizerD.zero_grad()
47 | with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
48 | # prepare_data
49 | real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device)
50 | real = real.requires_grad_()
51 | sent_emb = sent_emb.requires_grad_()
52 | words_embs = words_embs.requires_grad_()
53 | # predict real
54 | CLIP_real,real_emb = image_encoder(real)
55 | real_feats = netD(CLIP_real)
56 | pred_real, errD_real = predict_loss(netC, real_feats, sent_emb, negtive=False)
57 | # predict mismatch
58 | mis_sent_emb = torch.cat((sent_emb[1:], sent_emb[0:1]), dim=0).detach()
59 | _, errD_mis = predict_loss(netC, real_feats, mis_sent_emb, negtive=True)
60 | # synthesize fake images
61 | noise = torch.randn(batch_size, z_dim).to(device)
62 | fake = netG(noise, sent_emb)
63 | CLIP_fake, fake_emb = image_encoder(fake)
64 | fake_feats = netD(CLIP_fake.detach())
65 | _, errD_fake = predict_loss(netC, fake_feats, sent_emb, negtive=True)
66 | # MA-GP
67 | if args.mixed_precision:
68 | errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, scaler_D)
69 | else:
70 | errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real)
71 | # whole D loss
72 | with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc:
73 | errD = errD_real + (errD_fake + errD_mis)/2.0 + errD_MAGP
74 | # update D
75 | if args.mixed_precision:
76 | scaler_D.scale(errD).backward()
77 | scaler_D.step(optimizerD)
78 | scaler_D.update()
79 | if scaler_D.get_scale() < args.scaler_min:
    | # [-1, 1] --> [0, 255]
204 | im = (im + 1.0) * 127.5
205 | im = im.astype(np.uint8)
206 | im = np.transpose(im, (1, 2, 0))
207 | im = Image.fromarray(im)
208 | ######################################################
209 | # (3) Save fake images
210 | ######################################################
211 | if multi_gpus==True:
212 | single_img_name = 'batch_%04d.png'%(j)
213 | single_img_save_dir = osp.join(save_dir, 'single', str('gpu%d'%(get_rank())), 'step%04d'%(step))
214 | single_img_save_name = osp.join(single_img_save_dir, single_img_name)
215 | else:
216 | single_img_name = 'step_%04d.png'%(step)
217 | single_img_save_dir = osp.join(save_dir, 'single', 'step%04d'%(step))
218 | single_img_save_name = osp.join(single_img_save_dir, single_img_name)
219 | mkdir_p(single_img_save_dir)
220 | im.save(single_img_save_name)
221 | if (multi_gpus==True) and (get_rank() != 0):
222 | None
223 | else:
224 | print('Step: %d' % (step))
225 |
226 |
227 | def calculate_FID_CLIP_sim(dataloader, text_encoder, netG, CLIP, device, m1, s1, epoch, max_epoch, times, z_dim, batch_size):
228 | """ Calculates the FID """
229 | clip_cos = torch.FloatTensor([0.0]).to(device)
230 | # prepare Inception V3
231 | dims = 2048
232 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
233 | model = InceptionV3([block_idx])
234 | model.to(device)
235 | model.eval()
236 | netG.eval()
237 | norm = transforms.Compose([
238 | transforms.Normalize((-1, -1, -1), (2, 2, 2)),
239 | transforms.Resize((299, 299)),
240 | ])
241 | n_gpu = dist.get_world_size()
242 | dl_length = dataloader.__len__()
243 | imgs_num = dl_length * n_gpu * batch_size * times
244 | pred_arr = np.empty((imgs_num, dims))
245 | if (n_gpu!=1) and (get_rank() != 0):
246 | None
247 | else:
248 | loop = tqdm(total=int(dl_length*times))
249 | for time in range(times):
250 | for i, data in enumerate(dataloader):
251 | start = i * batch_size * n_gpu + time * dl_length * n_gpu * batch_size
252 | end = start + batch_size * n_gpu
253 | ######################################################
254 | # (1) Prepare_data
255 | ######################################################
256 | imgs, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device)
257 | ######################################################
258 | # (2) Generate fake images
259 | ######################################################
260 | batch_size = sent_emb.size(0)
261 | netG.eval()
262 | with torch.no_grad():
263 | noise = torch.randn(batch_size, z_dim).to(device)
264 | fake_imgs = netG(noise,sent_emb,eval=True).float()
265 | # norm_ip(fake_imgs, -1, 1)
266 | fake_imgs = torch.clamp(fake_imgs, -1., 1.)
267 | fake_imgs = torch.nan_to_num(fake_imgs, nan=-1.0, posinf=1.0, neginf=-1.0)
268 | clip_sim = calc_clip_sim(CLIP, fake_imgs, CLIP_tokens, device)
269 | clip_cos = clip_cos + clip_sim
270 | fake = norm(fake_imgs)
271 | pred = model(fake)[0]
272 | if pred.shape[2] != 1 or pred.shape[3] != 1:
273 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
274 | # concat pred from multi GPUs
275 | output = list(torch.empty_like(pred) for _ in range(n_gpu))
276 | dist.all_gather(output, pred)
277 | pred_all = torch.cat(output, dim=0).squeeze(-1).squeeze(-1)
278 | pred_arr[start:end] = pred_all.cpu().data.numpy()
279 | # update loop information
280 | if (n_gpu!=1) and (get_rank() != 0):
281 | None
282 | else:
283 | loop.update(1)
284 | if epoch==-1:
285 | loop.set_description('Evaluating')
286 | else:
287 | loop.set_description(f'Eval Epoch [{epoch}/{max_epoch}]')
288 | loop.set_postfix()
289 | if (n_gpu!=1) and (get_rank() != 0):
290 | None
291 | else:
292 | loop.close()
293 | # CLIP-score
294 | CLIP_score_gather = list(torch.empty_like(clip_cos) for _ in range(n_gpu))
295 | dist.all_gather(CLIP_score_gather, clip_cos)
296 | clip_score = torch.cat(CLIP_score_gather, dim=0).mean().item()/(dl_length*times)
297 | # FID
298 | m2 = np.mean(pred_arr, axis=0)
299 | s2 = np.cov(pred_arr, rowvar=False)
300 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
301 | return fid_value,clip_score
302 |
303 |
304 | def calc_clip_sim(clip, fake, caps_clip, device):
305 | ''' Calculate the cosine similarity between generated-image and text CLIP features.
306 | '''
307 | # Calculate features
308 | fake = transf_to_CLIP_input(fake)
309 | fake_features = clip.encode_image(fake)
310 | text_features = clip.encode_text(caps_clip)
311 | text_img_sim = torch.cosine_similarity(fake_features, text_features).mean()
312 | return text_img_sim
313 |
314 |
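    | # Save a fixed grid of generated samples each epoch: the first and second halves of the
    | # noise/sentence batch are rendered separately and concatenated into one image grid.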
315 | def sample_one_batch(noise, sent, netG, multi_gpus, epoch, img_save_dir, writer):
316 | if (multi_gpus==True) and (get_rank() != 0):
317 | None
318 | else:
319 | netG.eval()
320 | with torch.no_grad():
321 | B = noise.size(0)
322 | fixed_results_train = generate_samples(noise[:B//2], sent[:B//2], netG).cpu()
323 | torch.cuda.empty_cache()
324 | fixed_results_test = generate_samples(noise[B//2:], sent[B//2:], netG).cpu()
325 | torch.cuda.empty_cache()
326 | fixed_results = torch.cat((fixed_results_train, fixed_results_test), dim=0)
327 | img_name = 'samples_epoch_%03d.png'%(epoch)
328 | img_save_path = osp.join(img_save_dir, img_name)
329 | vutils.save_image(fixed_results.data, img_save_path, nrow=8, value_range=(-1, 1), normalize=True)
330 |
331 |
332 | def generate_samples(noise, caption, model):
333 | with torch.no_grad():
334 | fake = model(noise, caption, eval=True)
335 | return fake
336 |
337 |
338 | def predict_loss(predictor, img_feature, text_feature, negtive):
339 | output = predictor(img_feature, text_feature)
340 | err = hinge_loss(output, negtive)
341 | return output,err
342 |
343 |
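    | # Hinge loss: mean(relu(1 - out)) for matching (real) pairs and mean(relu(1 + out)) for
    | # mismatched or fake pairs.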
344 | def hinge_loss(output, negtive):
345 | if negtive==False:
346 | err = torch.mean(F.relu(1. - output))
347 | else:
348 | err = torch.mean(F.relu(1. + output))
349 | return err
350 |
351 |
352 | def logit_loss(output, negtive):
353 | batch_size = output.size(0)
354 | real_labels = torch.FloatTensor(batch_size,1).fill_(1).to(output.device)
355 | fake_labels = torch.FloatTensor(batch_size,1).fill_(0).to(output.device)
356 | output = nn.Sigmoid()(output)
357 | if negtive==False:
358 | err = nn.BCELoss()(output, real_labels)
359 | else:
360 | err = nn.BCELoss()(output, fake_labels)
361 | return err
362 |
363 |
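    | # Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2):
    | # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2)).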
364 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
365 | mu1 = np.atleast_1d(mu1)
366 | mu2 = np.atleast_1d(mu2)
367 |
368 | sigma1 = np.atleast_2d(sigma1)
369 | sigma2 = np.atleast_2d(sigma2)
370 |
371 | assert mu1.shape == mu2.shape, \
372 | 'Training and test mean vectors have different lengths'
373 | assert sigma1.shape == sigma2.shape, \
374 | 'Training and test covariances have different dimensions'
375 |
376 | diff = mu1 - mu2
377 | '''
378 | print('&'*20)
379 | print(sigma1)#, sigma1.type())
380 | print('&'*20)
381 | print(sigma2)#, sigma2.type())
382 | '''
383 | # Product might be almost singular
384 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
385 | if not np.isfinite(covmean).all():
386 | msg = ('fid calculation produces singular product; '
387 | 'adding %s to diagonal of cov estimates') % eps
388 | print(msg)
389 | offset = np.eye(sigma1.shape[0]) * eps
390 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
391 |
392 | # Numerical error might give slight imaginary component
393 | if np.iscomplexobj(covmean):
394 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
395 | m = np.max(np.abs(covmean.imag))
396 | raise ValueError('Imaginary component {}'.format(m))
397 | covmean = covmean.real
398 |
399 | tr_covmean = np.trace(covmean)
400 |
401 | return (diff.dot(diff) + np.trace(sigma1) +
402 | np.trace(sigma2) - 2 * tr_covmean)
403 |
--------------------------------------------------------------------------------