├── data └── readme.txt ├── code ├── src │ ├── readme.txt │ ├── sample.ipynb │ ├── test.py │ └── train.py ├── models │ ├── __init__.py │ ├── inception.py │ └── GALIP.py ├── cfg │ ├── coco.yml │ └── birds.yml ├── scripts │ ├── test.sh │ └── train.sh └── lib │ ├── perpare.py │ ├── datasets.py │ ├── utils.py │ └── modules.py ├── logo.jpeg ├── results.jpg ├── requirements.txt ├── LICENSE ├── .gitignore └── README.md /data/readme.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/src/readme.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /logo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tobran/GALIP/HEAD/logo.jpeg -------------------------------------------------------------------------------- /results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tobran/GALIP/HEAD/results.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | torch==1.11.0 4 | torchvision 5 | easydict 6 | pandas 7 | jupyter 8 | pyyaml 9 | ipykernel 10 | scipy 11 | tensorboard 12 | ftfy 13 | regex 14 | tqdm 15 | -------------------------------------------------------------------------------- /code/cfg/coco.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'coco' 2 | dataset_name: 'coco' 3 | data_dir: '../data/coco' 4 | 5 | imsize: 256 6 | z_dim: 100 7 | cond_dim: 512 8 | manual_seed: 100 9 | cuda: True 10 | 11 | clip4evl: {'src':"clip", 'type':'ViT-B/32'} 12 | clip4trn: {'src':"clip", 'type':'ViT-B/32'} 13 | clip4text: {'src':"clip", 'type':'ViT-B/32'} 14 | 15 | stamp: 'normal' 16 | state_epoch: 0 17 | max_epoch: 3005 18 | batch_size: 16 19 | gpu_id: 0 20 | nf: 64 21 | ch_size: 3 22 | 23 | scaler_min: 64 24 | growth_interval: 2000 25 | lr_g: 0.0001 26 | lr_d: 0.0004 27 | sim_w: 4.0 28 | 29 | gen_interval: 1 #1 30 | test_interval: 5 #5 31 | save_interval: 5 32 | 33 | sample_times: 1 34 | npz_path: '../data/coco/npz/coco_val256_FIDK0.npz' 35 | log_dir: 'new' 36 | -------------------------------------------------------------------------------- /code/cfg/birds.yml: -------------------------------------------------------------------------------- 1 | CONFIG_NAME: 'bird' 2 | dataset_name: 'birds' 3 | data_dir: '../data/birds' 4 | 5 | imsize: 256 6 | z_dim: 100 7 | cond_dim: 512 8 | manual_seed: 100 9 | cuda: True 10 | 11 | clip4evl: {'src':"clip", 'type':'ViT-B/32'} 12 | clip4trn: {'src':"clip", 'type':'ViT-B/32'} 13 | clip4text: {'src':"clip", 'type':'ViT-B/32'} 14 | 15 | stamp: 'normal' 16 | state_epoch: 0 17 | max_epoch: 1502 18 | batch_size: 16 19 | gpu_id: 0 20 | nf: 64 21 | ch_size: 3 22 | 23 | scaler_min: 64 24 | growth_interval: 2000 25 | lr_g: 0.0001 26 | lr_d: 0.0004 27 | sim_w: 4.0 28 | 29 | gen_interval: 5 #1 30 | test_interval: 20 #5 31 | save_interval: 20 32 | 33 | sample_times: 12 34 | npz_path: 
'../data/birds/npz/bird_val256_FIDK0.npz' 35 | log_dir: 'new' 36 | -------------------------------------------------------------------------------- /code/scripts/test.sh: -------------------------------------------------------------------------------- 1 | cfg=$1 2 | batch_size=64 3 | 4 | pretrained_model='./saved_models/data/model_save_file/xxx.pth' 5 | multi_gpus=True 6 | mixed_precision=True 7 | 8 | nodes=1 9 | num_workers=8 10 | master_port=11277 11 | stamp=gpu${nodes}MP_${mixed_precision} 12 | 13 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=$nodes --master_port=$master_port src/test.py \ 14 | --stamp $stamp \ 15 | --cfg $cfg \ 16 | --mixed_precision $mixed_precision \ 17 | --batch_size $batch_size \ 18 | --num_workers $num_workers \ 19 | --multi_gpus $multi_gpus \ 20 | --pretrained_model_path $pretrained_model \ 21 | -------------------------------------------------------------------------------- /code/scripts/train.sh: -------------------------------------------------------------------------------- 1 | cfg=$1 2 | batch_size=64 3 | 4 | state_epoch=1 5 | pretrained_model_path='./saved_models/data/model_save_file' 6 | log_dir='new' 7 | 8 | multi_gpus=True 9 | mixed_precision=True 10 | 11 | nodes=8 12 | num_workers=8 13 | master_port=11266 14 | stamp=gpu${nodes}MP_${mixed_precision} 15 | 16 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=$nodes --master_port=$master_port src/train.py \ 17 | --stamp $stamp \ 18 | --cfg $cfg \ 19 | --mixed_precision $mixed_precision \ 20 | --log_dir $log_dir \ 21 | --batch_size $batch_size \ 22 | --state_epoch $state_epoch \ 23 | --num_workers $num_workers \ 24 | --multi_gpus $multi_gpus \ 25 | --pretrained_model_path $pretrained_model_path \ 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MingTao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | logs/ 131 | imgs/ 132 | samples/ 133 | saved_models/ 134 | data/* 135 | __pycache__ 136 | tmp/ -------------------------------------------------------------------------------- /code/src/sample.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import os\n", 11 | "from PIL import Image\n", 12 | "import clip\n", 13 | "import os.path as osp\n", 14 | "import os, sys\n", 15 | "import torchvision.utils as vutils\n", 16 | "sys.path.insert(0, '../')\n", 17 | "\n", 18 | "from lib.utils import load_model_weights,mkdir_p\n", 19 | "from models.GALIP import NetG, CLIP_TXT_ENCODER" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "device = 'cpu' # 'cpu' # 'cuda:0'\n", 29 | "CLIP_text = \"ViT-B/32\"\n", 30 | "clip_model, preprocess = clip.load(\"ViT-B/32\", device=device)\n", 31 | "clip_model = clip_model.eval()" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "text_encoder = CLIP_TXT_ENCODER(clip_model).to(device)\n", 41 | "netG = NetG(64, 100, 512, 256, 3, False, clip_model).to(device)\n", 42 | "path = '../saved_models/pretrained/pre_cc12m.pth'\n", 43 | "checkpoint = torch.load(path, map_location=torch.device('cpu'))\n", 44 | "netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus=False)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 4, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "batch_size = 8\n", 54 | "noise = torch.randn((batch_size, 100)).to(device)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 5, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "captions = ['Line drawing illustration of a kawaii cute ghost.']" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 6, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "mkdir_p('./samples')" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 7, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "# generate from text\n", 82 | "with torch.no_grad():\n", 83 | " for i in range(len(captions)):\n", 84 | " caption = captions[i]\n", 85 | " tokenized_text = clip.tokenize([caption]).to(device)\n", 86 | " sent_emb, word_emb = text_encoder(tokenized_text)\n", 87 | " sent_emb = sent_emb.repeat(batch_size,1)\n", 88 | " fake_imgs = netG(noise,sent_emb,eval=True).float()\n", 89 | " name = f'{captions[i].replace(\" \", \"-\")}'\n", 90 | " vutils.save_image(fake_imgs.data, '../samples/%s.png'%(name), nrow=8, value_range=(-1, 1), normalize=True)" 91 | ] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "dfgan", 97 | 
"language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.9.0" 111 | }, 112 | "orig_nbformat": 4, 113 | "vscode": { 114 | "interpreter": { 115 | "hash": "849434eb86c3997df801551b732438d01b491303b69c29efcf332721ce6d8430" 116 | } 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 2 121 | } 122 | -------------------------------------------------------------------------------- /code/lib/perpare.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | import numpy as np 4 | from PIL import Image 5 | from tqdm import tqdm, trange 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision.transforms as transforms 9 | import torchvision.utils as vutils 10 | from torchvision.utils import save_image, make_grid 11 | from torch.utils.data import DataLoader, random_split 12 | from torch.utils.data.distributed import DistributedSampler 13 | import clip 14 | import importlib 15 | from lib.utils import choose_model 16 | 17 | 18 | ########### preparation ############ 19 | def load_clip(clip_info, device): 20 | import clip as clip 21 | model = clip.load(clip_info['type'], device=device)[0] 22 | return model 23 | 24 | 25 | def prepare_models(args): 26 | device = args.device 27 | local_rank = args.local_rank 28 | multi_gpus = args.multi_gpus 29 | CLIP4trn = load_clip(args.clip4trn, device).eval() 30 | CLIP4evl = load_clip(args.clip4evl, device).eval() 31 | NetG,NetD,NetC,CLIP_IMG_ENCODER,CLIP_TXT_ENCODER = choose_model(args.model) 32 | # image encoder 33 | CLIP_img_enc = CLIP_IMG_ENCODER(CLIP4trn).to(device) 34 | for p in CLIP_img_enc.parameters(): 35 | p.requires_grad = False 36 | CLIP_img_enc.eval() 37 | # text encoder 38 | CLIP_txt_enc = CLIP_TXT_ENCODER(CLIP4trn).to(device) 39 | for p in CLIP_txt_enc.parameters(): 40 | p.requires_grad = False 41 | CLIP_txt_enc.eval() 42 | # GAN models 43 | netG = NetG(args.nf, args.z_dim, args.cond_dim, args.imsize, args.ch_size, args.mixed_precision, CLIP4trn).to(device) 44 | netD = NetD(args.nf, args.imsize, args.ch_size, args.mixed_precision).to(device) 45 | netC = NetC(args.nf, args.cond_dim, args.mixed_precision).to(device) 46 | if (args.multi_gpus) and (args.train): 47 | print("Let's use", torch.cuda.device_count(), "GPUs!") 48 | netG = torch.nn.parallel.DistributedDataParallel(netG, broadcast_buffers=False, 49 | device_ids=[local_rank], 50 | output_device=local_rank, find_unused_parameters=True) 51 | netD = torch.nn.parallel.DistributedDataParallel(netD, broadcast_buffers=False, 52 | device_ids=[local_rank], 53 | output_device=local_rank, find_unused_parameters=True) 54 | netC = torch.nn.parallel.DistributedDataParallel(netC, broadcast_buffers=False, 55 | device_ids=[local_rank], 56 | output_device=local_rank, find_unused_parameters=True) 57 | return CLIP4trn, CLIP4evl, CLIP_img_enc, CLIP_txt_enc, netG, netD, netC 58 | 59 | 60 | def prepare_dataset(args, split, transform): 61 | if args.ch_size!=3: 62 | imsize = 256 63 | else: 64 | imsize = args.imsize 65 | if transform is not None: 66 | image_transform = transform 67 | else: 68 | image_transform = transforms.Compose([ 69 | transforms.Resize(int(imsize * 76 / 64)), 70 | transforms.RandomCrop(imsize), 71 | 
transforms.RandomHorizontalFlip(), 72 | ]) 73 | from lib.datasets import TextImgDataset as Dataset 74 | dataset = Dataset(split=split, transform=image_transform, args=args) 75 | return dataset 76 | 77 | 78 | def prepare_datasets(args, transform): 79 | # train dataset 80 | train_dataset = prepare_dataset(args, split='train', transform=transform) 81 | # test dataset 82 | val_dataset = prepare_dataset(args, split='test', transform=transform) 83 | return train_dataset, val_dataset 84 | 85 | 86 | def prepare_dataloaders(args, transform=None): 87 | batch_size = args.batch_size 88 | num_workers = args.num_workers 89 | train_dataset, valid_dataset = prepare_datasets(args, transform) 90 | # train dataloader 91 | if args.multi_gpus==True: 92 | train_sampler = DistributedSampler(train_dataset) 93 | train_dataloader = torch.utils.data.DataLoader( 94 | train_dataset, batch_size=batch_size, drop_last=True, 95 | num_workers=num_workers, sampler=train_sampler) 96 | else: 97 | train_sampler = None 98 | train_dataloader = torch.utils.data.DataLoader( 99 | train_dataset, batch_size=batch_size, drop_last=True, 100 | num_workers=num_workers, shuffle='True') 101 | # valid dataloader 102 | if args.multi_gpus==True: 103 | valid_sampler = DistributedSampler(valid_dataset) 104 | valid_dataloader = torch.utils.data.DataLoader( 105 | valid_dataset, batch_size=batch_size, drop_last=True, 106 | num_workers=num_workers, sampler=valid_sampler) 107 | else: 108 | valid_dataloader = torch.utils.data.DataLoader( 109 | valid_dataset, batch_size=batch_size, drop_last=True, 110 | num_workers=num_workers, shuffle='True') 111 | return train_dataloader, valid_dataloader, \ 112 | train_dataset, valid_dataset, train_sampler 113 | 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Visitors](https://visitor-badge.glitch.me/badge?page_id=tobran/GALIP) 2 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://github.com/tobran/GALIP/blob/master/LICENSE.md) 3 | ![Python 3.9](https://img.shields.io/badge/python-3.9-green.svg) 4 | ![Packagist](https://img.shields.io/badge/Pytorch-1.9.0-red.svg) 5 | ![hardware](https://img.shields.io/badge/GPU-CPU-1abc9c.svg) 6 | ![Last Commit](https://img.shields.io/github/last-commit/tobran/GALIP) 7 | [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-blue.svg)]((https://github.com/tobran/GALIP/graphs/commit-activity)) 8 | ![Ask Me Anything !](https://img.shields.io/badge/Ask%20me-anything-1a009c.svg) 9 | 10 | # GALIP: Generative Adversarial CLIPs for Text-to-Image Synthesis (CVPR 2023) 11 | 12 |
13 | [logo image: logo.jpeg] 14 | 
15 | 16 | # A high-quality, fast, and efficient text-to-image model 17 | 18 | Official Pytorch implementation for our paper [GALIP: Generative Adversarial CLIPs for Text-to-Image Synthesis](https://arxiv.org/abs/2301.12959) by [Ming Tao](https://scholar.google.com/citations?user=5GlOlNUAAAAJ), [Bing-Kun Bao](https://scholar.google.com/citations?user=lDppvmoAAAAJ&hl=en), [Hao Tang](https://scholar.google.com/citations?user=9zJkeEMAAAAJ&hl=en), [Changsheng Xu](https://scholar.google.com/citations?user=hI9NRDkAAAAJ). 19 | 20 |
21 | Generated Images (results.jpg) 22 | 23 | 24 | 25 | 
26 | 27 | 28 | ## Requirements 29 | - python 3.9 30 | - Pytorch 1.9 31 | - At least 1x24GB 3090 GPU (for training) 32 | - CPU only (for sampling) 33 | 34 | GALIP is a small and fast generative model which can generate multiple pictures in one second, even on the CPU. 35 | ## Installation 36 | 37 | Clone this repo. 38 | ``` 39 | git clone https://github.com/tobran/GALIP 40 | pip install -r requirements.txt 41 | ``` 42 | Install [CLIP](https://github.com/openai/CLIP) 43 | 44 | ## Preparation (Same as DF-GAN) 45 | ### Datasets 46 | 1. Download the preprocessed metadata for [birds](https://drive.google.com/file/d/1I6ybkR7L64K8hZOraEZDuHh0cCJw5OUj/view?usp=sharing) and [coco](https://drive.google.com/file/d/15Fw-gErCEArOFykW3YTnLKpRcPgI_3AB/view?usp=sharing) and extract them to `data/` 47 | 2. Download the [birds](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) image data and extract it to `data/birds/` 48 | 3. Download the [coco2014](http://cocodataset.org/#download) dataset and extract the images to `data/coco/images/` 49 | 50 | ## Training 51 | ``` 52 | cd GALIP/code/ 53 | ``` 54 | ### Train the GALIP model 55 | - For the bird dataset: `bash scripts/train.sh ./cfg/birds.yml` 56 | - For the coco dataset: `bash scripts/train.sh ./cfg/coco.yml` 57 | ### Resume training process 58 | If your training process is interrupted unexpectedly, set **state_epoch**, **log_dir**, and **pretrained_model_path** in train.sh to resume training. 59 | 60 | ### TensorBoard 61 | Our code supports automatic FID evaluation during training; the results are stored in TensorBoard files under `./logs`. You can change the test interval by changing **test_interval** in the YAML file. 62 | 63 | - For the bird dataset: `tensorboard --logdir=./code/logs/bird/train --port 8166` 64 | - For the coco dataset: `tensorboard --logdir=./code/logs/coco/train --port 8177` 65 | 66 | 67 | ## Evaluation 68 | 69 | ### Download Pretrained Model 70 | - [GALIP for COCO](https://drive.google.com/file/d/1gbfwDeD7ftZmdOFxfffCjKCyYfF4ptdl/view?usp=sharing). Download and save it to `./code/saved_models/pretrained/` 71 | - [GALIP for CC12M](https://drive.google.com/file/d/1VnONvNRjuyHTzuLKBbozZ38-WIt7XZMC/view?usp=sharing). Download and save it to `./code/saved_models/pretrained/` 72 | 73 | ### Evaluate GALIP models 74 | 75 | ``` 76 | cd GALIP/code/ 77 | ``` 78 | Set **pretrained_model** in test.sh 79 | - For the bird dataset: `bash scripts/test.sh ./cfg/birds.yml` 80 | - For the COCO dataset: `bash scripts/test.sh ./cfg/coco.yml` 81 | - For the CC12M (zero-shot on COCO) dataset: `bash scripts/test.sh ./cfg/coco.yml` 82 | 83 | ### Performance 84 | The released model achieves better performance than the paper version. 
85 | 86 | 87 | | Model | COCO-FID↓ | COCO-CS↑ | CC12M-ZFID↓ | 88 | | --- | --- | --- | --- | 89 | | GALIP(paper) | 5.85 | 0.3338 | 12.54 | 90 | | GALIP(released) | **5.01** | **0.3379** | **12.54** | 91 | 92 | 93 | ## Sampling 94 | 95 | ### Synthesize images from your text descriptions 96 | - the sample.ipynb can be used to sample 97 | 98 | --- 99 | ### Citing GALIP 100 | 101 | If you find GALIP useful in your research, please consider citing our paper: 102 | ``` 103 | 104 | @inproceedings{tao2023galip, 105 | title={GALIP: Generative Adversarial CLIPs for Text-to-Image Synthesis}, 106 | author={Tao, Ming and Bao, Bing-Kun and Tang, Hao and Xu, Changsheng}, 107 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 108 | pages={14214--14223}, 109 | year={2023} 110 | } 111 | 112 | ``` 113 | The code is released for academic research use only. For commercial use, please contact Ming Tao (陶明) (mingtao2000@126.com). 114 | 115 | 116 | **Reference** 117 | - [DF-GAN: A Simple and Effective Baseline for Text-to-Image Synthesis](https://arxiv.org/abs/2008.05865) [[code]](https://github.com/tobran/DF-GAN) 118 | -------------------------------------------------------------------------------- /code/models/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling featurs 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=True, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, normalizes the input to the statistics the pretrained 43 | Inception network expects 44 | requires_grad : bool 45 | If true, parameters of the model require gradient. 
Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | nn.MaxPool2d(kernel_size=3, stride=2) 68 | ] 69 | self.blocks.append(nn.Sequential(*block0)) 70 | 71 | # Block 1: maxpool1 to maxpool2 72 | if self.last_needed_block >= 1: 73 | block1 = [ 74 | inception.Conv2d_3b_1x1, 75 | inception.Conv2d_4a_3x3, 76 | nn.MaxPool2d(kernel_size=3, stride=2) 77 | ] 78 | self.blocks.append(nn.Sequential(*block1)) 79 | 80 | # Block 2: maxpool2 to aux classifier 81 | if self.last_needed_block >= 2: 82 | block2 = [ 83 | inception.Mixed_5b, 84 | inception.Mixed_5c, 85 | inception.Mixed_5d, 86 | inception.Mixed_6a, 87 | inception.Mixed_6b, 88 | inception.Mixed_6c, 89 | inception.Mixed_6d, 90 | inception.Mixed_6e, 91 | ] 92 | self.blocks.append(nn.Sequential(*block2)) 93 | 94 | # Block 3: aux classifier to final avgpool 95 | if self.last_needed_block >= 3: 96 | block3 = [ 97 | inception.Mixed_7a, 98 | inception.Mixed_7b, 99 | inception.Mixed_7c, 100 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 101 | ] 102 | self.blocks.append(nn.Sequential(*block3)) 103 | 104 | for param in self.parameters(): 105 | param.requires_grad = requires_grad 106 | 107 | def forward(self, inp): 108 | """Get Inception feature maps 109 | 110 | Parameters 111 | ---------- 112 | inp : torch.autograd.Variable 113 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 114 | range (0, 1) 115 | 116 | Returns 117 | ------- 118 | List of torch.autograd.Variable, corresponding to the selected output 119 | block, sorted ascending by index 120 | """ 121 | outp = [] 122 | x = inp 123 | 124 | if self.resize_input: 125 | x = F.upsample(x, size=(299, 299), mode='bilinear', align_corners=True) 126 | 127 | if self.normalize_input: 128 | x = x.clone() 129 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 130 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 131 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 132 | 133 | for idx, block in enumerate(self.blocks): 134 | x = block(x) 135 | if idx in self.output_blocks: 136 | outp.append(x) 137 | 138 | if idx == self.last_needed_block: 139 | break 140 | 141 | return outp 142 | -------------------------------------------------------------------------------- /code/src/test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | import time 4 | import random 5 | import argparse 6 | import numpy as np 7 | from PIL import Image 8 | import pprint 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.autograd import Variable 14 | import torch.backends.cudnn as cudnn 15 | from torchvision.utils import save_image,make_grid 16 | from torch.utils.tensorboard import SummaryWriter 17 | import torchvision.transforms as transforms 18 | import torchvision.utils as vutils 19 | from torch.utils.data.distributed import DistributedSampler 20 | import multiprocessing as mp 21 | 22 | ROOT_PATH = osp.abspath(osp.join(osp.dirname(osp.abspath(__file__)), "..")) 23 | sys.path.insert(0, ROOT_PATH) 24 | from lib.utils import mkdir_p,get_rank,merge_args_yaml,get_time_stamp,save_args 25 | from lib.utils import load_netG,load_npz,save_models 26 | from lib.perpare import prepare_dataloaders 27 | from lib.perpare import prepare_models 28 | from lib.modules import test as test 29 | 30 | 31 | def parse_args(): 32 | # Training settings 33 | parser = argparse.ArgumentParser(description='Text2Img') 34 | parser.add_argument('--cfg', dest='cfg_file', type=str, default='../cfg/coco.yml', 35 | help='optional config file') 36 | parser.add_argument('--num_workers', type=int, default=4, 37 | help='number of workers(default: {0})'.format(mp.cpu_count() - 1)) 38 | parser.add_argument('--stamp', type=str, default='normal', 39 | help='the stamp of model') 40 | parser.add_argument('--pretrained_model_path', type=str, default='model', 41 | help='the model for training') 42 | parser.add_argument('--log_dir', type=str, default='new', 43 | help='file path to log directory') 44 | parser.add_argument('--model', type=str, default='GALIP', 45 | help='the model for training') 46 | parser.add_argument('--state_epoch', type=int, default=100, 47 | help='state epoch') 48 | parser.add_argument('--batch_size', type=int, default=1024, 49 | help='batch size') 50 | parser.add_argument('--train', type=str, default='True', 51 | help='if train model') 52 | parser.add_argument('--mixed_precision', type=str, default='False', 53 | help='if use multi-gpu') 54 | parser.add_argument('--multi_gpus', type=str, default='False', 55 | help='if use multi-gpu') 56 | parser.add_argument('--gpu_id', type=int, default=1, 57 | help='gpu id') 58 | parser.add_argument('--local_rank', default=-1, type=int, 59 | help='node rank for distributed training') 60 | parser.add_argument('--random_sample', action='store_true',default=True, 61 | 
help='whether to sample the dataset with random sampler') 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def main(args): 67 | time_stamp = get_time_stamp() 68 | stamp = '_'.join([str(args.model),str(args.stamp),str(args.CONFIG_NAME),str(args.imsize),time_stamp]) 69 | log_dir = osp.join(ROOT_PATH, 'logs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp))) 70 | if (args.multi_gpus==True) and (get_rank() != 0): 71 | None 72 | else: 73 | mkdir_p(osp.join(ROOT_PATH, 'logs')) 74 | # prepare TensorBoard 75 | if (args.multi_gpus==True) and (get_rank() != 0): 76 | writer = None 77 | else: 78 | writer = SummaryWriter(log_dir) 79 | # Build and load the generator 80 | # prepare dataloader, models, data 81 | train_dl, valid_dl ,train_ds, valid_ds, sampler = prepare_dataloaders(args) 82 | CLIP4trn, CLIP4evl, image_encoder, text_encoder, netG, netD, netC = prepare_models(args) 83 | state_path = args.pretrained_model_path 84 | multi_gpus = args.multi_gpus 85 | m1, s1 = load_npz(args.npz_path) 86 | netG = load_netG(netG, state_path, multi_gpus, args.train) 87 | 88 | save_models(netG, netD, netC, 0, args.multi_gpus, './tmp') 89 | 90 | netG.eval() 91 | FID, TI_score = test(valid_dl, text_encoder, netG, CLIP4evl, args.device, m1, s1, -1, -1, \ 92 | args.sample_times, args.z_dim, args.batch_size) 93 | if (args.multi_gpus==True) and (get_rank() != 0): 94 | None 95 | else: 96 | print('FID: %.2f, CLIP_Score: %.2f' % (FID, TI_score*100)) 97 | 98 | 99 | if __name__ == "__main__": 100 | args = merge_args_yaml(parse_args()) 101 | # set seed 102 | if args.manual_seed is None: 103 | args.manual_seed = 100 104 | #args.manualSeed = random.randint(1, 10000) 105 | random.seed(args.manual_seed) 106 | np.random.seed(args.manual_seed) 107 | torch.manual_seed(args.manual_seed) 108 | if args.cuda: 109 | if args.multi_gpus: 110 | torch.cuda.manual_seed_all(args.manual_seed) 111 | torch.distributed.init_process_group(backend="nccl") 112 | local_rank = torch.distributed.get_rank() 113 | torch.cuda.set_device(local_rank) 114 | args.device = torch.device("cuda", local_rank) 115 | args.local_rank = local_rank 116 | else: 117 | torch.cuda.manual_seed_all(args.manual_seed) 118 | torch.cuda.set_device(args.gpu_id) 119 | args.device = torch.device("cuda") 120 | else: 121 | args.device = torch.device('cpu') 122 | main(args) 123 | 124 | 125 | -------------------------------------------------------------------------------- /code/lib/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import pandas as pd 6 | from PIL import Image 7 | import numpy.random as random 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | import torch 13 | import torch.utils.data as data 14 | from torch.autograd import Variable 15 | import torchvision.transforms as transforms 16 | import clip as clip 17 | 18 | 19 | def get_fix_data(train_dl, test_dl, text_encoder, args): 20 | fixed_image_train, _, _, fixed_sent_train, fixed_word_train, fixed_key_train = get_one_batch_data(train_dl, text_encoder, args) 21 | fixed_image_test, _, _, fixed_sent_test, fixed_word_test, fixed_key_test= get_one_batch_data(test_dl, text_encoder, args) 22 | fixed_image = torch.cat((fixed_image_train, fixed_image_test), dim=0) 23 | fixed_sent = torch.cat((fixed_sent_train, fixed_sent_test), dim=0) 24 | fixed_word = torch.cat((fixed_word_train, fixed_word_test), dim=0) 25 | fixed_noise = 
torch.randn(fixed_image.size(0), args.z_dim).to(args.device) 26 | return fixed_image, fixed_sent, fixed_word, fixed_noise 27 | 28 | 29 | def get_one_batch_data(dataloader, text_encoder, args): 30 | data = next(iter(dataloader)) 31 | imgs, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, args.device) 32 | return imgs, captions, CLIP_tokens, sent_emb, words_embs, keys 33 | 34 | 35 | def prepare_data(data, text_encoder, device): 36 | imgs, captions, CLIP_tokens, keys = data 37 | imgs, CLIP_tokens = imgs.to(device), CLIP_tokens.to(device) 38 | sent_emb, words_embs = encode_tokens(text_encoder, CLIP_tokens) 39 | return imgs, captions, CLIP_tokens, sent_emb, words_embs, keys 40 | 41 | 42 | def encode_tokens(text_encoder, caption): 43 | # encode text 44 | with torch.no_grad(): 45 | sent_emb,words_embs = text_encoder(caption) 46 | sent_emb,words_embs = sent_emb.detach(), words_embs.detach() 47 | return sent_emb, words_embs 48 | 49 | 50 | def get_imgs(img_path, bbox=None, transform=None, normalize=None): 51 | img = Image.open(img_path).convert('RGB') 52 | width, height = img.size 53 | if bbox is not None: 54 | r = int(np.maximum(bbox[2], bbox[3]) * 0.75) 55 | center_x = int((2 * bbox[0] + bbox[2]) / 2) 56 | center_y = int((2 * bbox[1] + bbox[3]) / 2) 57 | y1 = np.maximum(0, center_y - r) 58 | y2 = np.minimum(height, center_y + r) 59 | x1 = np.maximum(0, center_x - r) 60 | x2 = np.minimum(width, center_x + r) 61 | img = img.crop([x1, y1, x2, y2]) 62 | if transform is not None: 63 | img = transform(img) 64 | if normalize is not None: 65 | img = normalize(img) 66 | return img 67 | 68 | 69 | def get_caption(cap_path,clip_info): 70 | eff_captions = [] 71 | with open(cap_path, "r") as f: 72 | captions = f.read().encode('utf-8').decode('utf8').split('\n') 73 | for cap in captions: 74 | if len(cap) != 0: 75 | eff_captions.append(cap) 76 | sent_ix = random.randint(0, len(eff_captions)) 77 | caption = eff_captions[sent_ix] 78 | tokens = clip.tokenize(caption,truncate=True) 79 | return caption, tokens[0] 80 | 81 | 82 | ################################################################ 83 | # Dataset 84 | ################################################################ 85 | class TextImgDataset(data.Dataset): 86 | def __init__(self, split, transform=None, args=None): 87 | self.transform = transform 88 | self.clip4text = args.clip4text 89 | self.data_dir = args.data_dir 90 | self.dataset_name = args.dataset_name 91 | self.norm = transforms.Compose([ 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 94 | ]) 95 | self.split=split 96 | 97 | if self.data_dir.find('birds') != -1: 98 | self.bbox = self.load_bbox() 99 | else: 100 | self.bbox = None 101 | self.split_dir = os.path.join(self.data_dir, split) 102 | self.filenames = self.load_filenames(self.data_dir, split) 103 | self.number_example = len(self.filenames) 104 | 105 | def load_bbox(self): 106 | data_dir = self.data_dir 107 | bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') 108 | df_bounding_boxes = pd.read_csv(bbox_path, 109 | delim_whitespace=True, 110 | header=None).astype(int) 111 | # 112 | filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') 113 | df_filenames = \ 114 | pd.read_csv(filepath, delim_whitespace=True, header=None) 115 | filenames = df_filenames[1].tolist() 116 | print('Total filenames: ', len(filenames), filenames[0]) 117 | # 118 | filename_bbox = {img_file[:-4]: [] for img_file in filenames} 119 | numImgs = len(filenames) 120 | for i in 
range(0, numImgs): 121 | # bbox = [x-left, y-top, width, height] 122 | bbox = df_bounding_boxes.iloc[i][1:].tolist() 123 | key = filenames[i][:-4] 124 | filename_bbox[key] = bbox 125 | return filename_bbox 126 | 127 | def load_filenames(self, data_dir, split): 128 | filepath = '%s/%s/filenames.pickle' % (data_dir, split) 129 | if os.path.isfile(filepath): 130 | with open(filepath, 'rb') as f: 131 | filenames = pickle.load(f) 132 | print('Load filenames from: %s (%d)' % (filepath, len(filenames))) 133 | else: 134 | filenames = [] 135 | return filenames 136 | 137 | def __getitem__(self, index): 138 | # 139 | key = self.filenames[index] 140 | data_dir = self.data_dir 141 | # 142 | if self.bbox is not None: 143 | bbox = self.bbox[key] 144 | else: 145 | bbox = None 146 | # 147 | if self.dataset_name.lower().find('coco') != -1: 148 | if self.split=='train': 149 | img_name = '%s/images/train2014/jpg/%s.jpg' % (data_dir, key) 150 | text_name = '%s/text/%s.txt' % (data_dir, key) 151 | else: 152 | img_name = '%s/images/val2014/jpg/%s.jpg' % (data_dir, key) 153 | text_name = '%s/text/%s.txt' % (data_dir, key) 154 | elif self.dataset_name.lower().find('cc3m') != -1: 155 | if self.split=='train': 156 | img_name = '%s/images/train/%s.jpg' % (data_dir, key) 157 | text_name = '%s/text/train/%s.txt' % (data_dir, key.split('_')[0]) 158 | else: 159 | img_name = '%s/images/test/%s.jpg' % (data_dir, key) 160 | text_name = '%s/text/test/%s.txt' % (data_dir, key.split('_')[0]) 161 | elif self.dataset_name.lower().find('cc12m') != -1: 162 | if self.split=='train': 163 | img_name = '%s/images/%s.jpg' % (data_dir, key) 164 | text_name = '%s/text/%s.txt' % (data_dir, key.split('_')[0]) 165 | else: 166 | img_name = '%s/images/%s.jpg' % (data_dir, key) 167 | text_name = '%s/text/%s.txt' % (data_dir, key.split('_')[0]) 168 | else: 169 | img_name = '%s/CUB_200_2011/images/%s.jpg' % (data_dir, key) 170 | text_name = '%s/text/%s.txt' % (data_dir, key) 171 | # 172 | imgs = get_imgs(img_name, bbox, self.transform, normalize=self.norm) 173 | caps,tokens = get_caption(text_name,self.clip4text) 174 | return imgs, caps, tokens, key 175 | 176 | def __len__(self): 177 | return len(self.filenames) 178 | 179 | -------------------------------------------------------------------------------- /code/lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import errno 4 | import numpy as np 5 | import numpy.random as random 6 | import torch 7 | from torch import distributed as dist 8 | from tqdm import tqdm 9 | import yaml 10 | from easydict import EasyDict as edict 11 | import pprint 12 | import datetime 13 | import dateutil.tz 14 | from PIL import Image 15 | 16 | import importlib 17 | from torchvision.transforms import InterpolationMode 18 | import torch.nn.functional as F 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | def choose_model(model): 27 | '''choose models 28 | ''' 29 | model = importlib.import_module(".%s"%(model), "models") 30 | NetG, NetD, NetC, CLIP_IMG_ENCODER, CLIP_TXT_ENCODER = model.NetG, model.NetD, model.NetC, model.CLIP_IMG_ENCODER, model.CLIP_TXT_ENCODER 31 | return NetG,NetD,NetC,CLIP_IMG_ENCODER, CLIP_TXT_ENCODER 32 | 33 | 34 | def params_count(model): 35 | model_size = np.sum([p.numel() for p in model.parameters()]).item() 36 | return model_size 37 | 38 | 39 | def get_time_stamp(): 40 | now = 
datetime.datetime.now(dateutil.tz.tzlocal()) 41 | timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') 42 | return timestamp 43 | 44 | 45 | def mkdir_p(path): 46 | try: 47 | os.makedirs(path) 48 | except OSError as exc: # Python >2.5 49 | if exc.errno == errno.EEXIST and os.path.isdir(path): 50 | pass 51 | else: 52 | raise 53 | 54 | 55 | def load_npz(path): 56 | f = np.load(path) 57 | m, s = f['mu'][:], f['sigma'][:] 58 | f.close() 59 | return m, s 60 | 61 | # config 62 | def load_yaml(filename): 63 | with open(filename, 'r') as f: 64 | cfg = edict(yaml.load(f, Loader=yaml.FullLoader)) 65 | return cfg 66 | 67 | 68 | def str2bool_dict(dict): 69 | for key,value in dict.items(): 70 | if type(value)==str: 71 | if value.lower() in ('yes','true'): 72 | dict[key] = True 73 | elif value.lower() in ('no','false'): 74 | dict[key] = False 75 | else: 76 | None 77 | return dict 78 | 79 | 80 | def merge_args_yaml(args): 81 | if args.cfg_file is not None: 82 | opt = vars(args) 83 | args = load_yaml(args.cfg_file) 84 | args.update(opt) 85 | args = str2bool_dict(args) 86 | args = edict(args) 87 | return args 88 | 89 | 90 | def save_args(save_path, args): 91 | fp = open(save_path, 'w') 92 | fp.write(yaml.dump(args)) 93 | fp.close() 94 | 95 | 96 | def read_txt_file(txt_file): 97 | content = [] 98 | with open(txt_file, "r") as f: 99 | for line in f.readlines(): 100 | line = line.strip('\n') 101 | content.append(line) 102 | return content 103 | 104 | 105 | def get_rank(): 106 | if not dist.is_available(): 107 | return 0 108 | if not dist.is_initialized(): 109 | return 0 110 | return dist.get_rank() 111 | 112 | 113 | def load_opt_weights(optimizer, weights): 114 | optimizer.load_state_dict(weights) 115 | return optimizer 116 | 117 | 118 | def load_models_opt(netG, netD, netC, optim_G, optim_D, path, multi_gpus): 119 | checkpoint = torch.load(path, map_location=torch.device('cpu')) 120 | netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus) 121 | netD = load_model_weights(netD, checkpoint['model']['netD'], multi_gpus) 122 | netC = load_model_weights(netC, checkpoint['model']['netC'], multi_gpus) 123 | optim_G = load_opt_weights(optim_G, checkpoint['optimizers']['optimizer_G']) 124 | optim_D = load_opt_weights(optim_D, checkpoint['optimizers']['optimizer_D']) 125 | return netG, netD, netC, optim_G, optim_D 126 | 127 | 128 | def load_models(netG, netD, netC, path): 129 | checkpoint = torch.load(path, map_location=torch.device('cpu')) 130 | netG = load_model_weights(netG, checkpoint['model']['netG']) 131 | netD = load_model_weights(netD, checkpoint['model']['netD']) 132 | netC = load_model_weights(netC, checkpoint['model']['netC']) 133 | return netG, netD, netC 134 | 135 | 136 | def load_netG(netG, path, multi_gpus, train): 137 | checkpoint = torch.load(path, map_location="cpu") 138 | netG = load_model_weights(netG, checkpoint['model']['netG'], multi_gpus, train) 139 | return netG 140 | 141 | 142 | def load_model_weights(model, weights, multi_gpus, train=True): 143 | if list(weights.keys())[0].find('module')==-1: 144 | pretrained_with_multi_gpu = False 145 | else: 146 | pretrained_with_multi_gpu = True 147 | if (multi_gpus==False) or (train==False): 148 | if pretrained_with_multi_gpu: 149 | state_dict = { 150 | key[7:]: value 151 | for key, value in weights.items() 152 | } 153 | else: 154 | state_dict = weights 155 | else: 156 | state_dict = weights 157 | model.load_state_dict(state_dict) 158 | return model 159 | 160 | 161 | def save_models_opt(netG, netD, netC, optG, optD, epoch, multi_gpus, 
save_path): 162 | if (multi_gpus==True) and (get_rank() != 0): 163 | None 164 | else: 165 | state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, \ 166 | 'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},\ 167 | 'epoch': epoch} 168 | torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch)) 169 | 170 | 171 | def save_models(netG, netD, netC, epoch, multi_gpus, save_path): 172 | if (multi_gpus==True) and (get_rank() != 0): 173 | None 174 | else: 175 | state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}} 176 | torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch)) 177 | 178 | 179 | def save_checkpoints(netG, netD, netC, optG, optD, scaler_G, scaler_D, epoch, multi_gpus, save_path): 180 | if (multi_gpus==True) and (get_rank() != 0): 181 | None 182 | else: 183 | state = {'model': {'netG': netG.state_dict(), 'netD': netD.state_dict(), 'netC': netC.state_dict()}, \ 184 | 'optimizers': {'optimizer_G': optG.state_dict(), 'optimizer_D': optD.state_dict()},\ 185 | "scalers": {"scaler_G": scaler_G.state_dict(), "scaler_D": scaler_D.state_dict()},\ 186 | 'epoch': epoch} 187 | torch.save(state, '%s/state_epoch_%03d.pth' % (save_path, epoch)) 188 | 189 | 190 | def write_to_txt(filename, contents): 191 | fh = open(filename, 'w') 192 | fh.write(contents) 193 | fh.close() 194 | 195 | 196 | def read_txt_file(txt_file): 197 | # text_file: file path 198 | content = [] 199 | with open(txt_file, "r") as f: 200 | for line in f.readlines(): 201 | line = line.strip('\n') 202 | content.append(line) 203 | return content 204 | 205 | 206 | def save_img(img, path): 207 | im = img.data.cpu().numpy() 208 | # [-1, 1] --> [0, 255] 209 | im = (im + 1.0) * 127.5 210 | im = im.astype(np.uint8) 211 | im = np.transpose(im, (1, 2, 0)) 212 | im = Image.fromarray(im) 213 | im.save(path) 214 | 215 | 216 | def transf_to_CLIP_input(inputs): 217 | device = inputs.device 218 | if len(inputs.size()) != 4: 219 | raise ValueError('Expect the (B, C, X, Y) tensor.') 220 | else: 221 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])\ 222 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) 223 | var = torch.tensor([0.26862954, 0.26130258, 0.27577711])\ 224 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) 225 | inputs = F.interpolate(inputs*0.5+0.5, size=(224, 224)) 226 | inputs = ((inputs+1)*0.5-mean)/var 227 | return inputs.float() 228 | 229 | 230 | class dummy_context_mgr(): 231 | def __enter__(self): 232 | return None 233 | 234 | def __exit__(self, exc_type, exc_value, traceback): 235 | return False 236 | -------------------------------------------------------------------------------- /code/src/train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import os.path as osp 3 | import time 4 | import random 5 | import argparse 6 | import numpy as np 7 | from PIL import Image 8 | import pprint 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.autograd import Variable 14 | import torch.backends.cudnn as cudnn 15 | from torchvision.utils import save_image,make_grid 16 | from torch.utils.tensorboard import SummaryWriter 17 | import torchvision.transforms as transforms 18 | import torchvision.utils as vutils 19 | from torch.utils.data.distributed import DistributedSampler 20 | import multiprocessing as mp 21 | 22 | ROOT_PATH = osp.abspath(osp.join(osp.dirname(osp.abspath(__file__)), 
"..")) 23 | sys.path.insert(0, ROOT_PATH) 24 | from lib.utils import mkdir_p,get_rank,merge_args_yaml,get_time_stamp,save_args 25 | from lib.utils import load_models_opt,save_models_opt,save_models,load_npz,params_count 26 | from lib.perpare import prepare_dataloaders,prepare_models 27 | from lib.modules import sample_one_batch as sample, test as test, train as train 28 | from lib.datasets import get_fix_data 29 | 30 | 31 | def parse_args(): 32 | # Training settings 33 | parser = argparse.ArgumentParser(description='Text2Img') 34 | parser.add_argument('--cfg', dest='cfg_file', type=str, default='../cfg/coco.yml', 35 | help='optional config file') 36 | parser.add_argument('--num_workers', type=int, default=4, 37 | help='number of workers(default: {0})'.format(mp.cpu_count() - 1)) 38 | parser.add_argument('--stamp', type=str, default='normal', 39 | help='the stamp of model') 40 | parser.add_argument('--pretrained_model_path', type=str, default='model', 41 | help='the model for training') 42 | parser.add_argument('--log_dir', type=str, default='new', 43 | help='file path to log directory') 44 | parser.add_argument('--model', type=str, default='GALIP', 45 | help='the model for training') 46 | parser.add_argument('--state_epoch', type=int, default=100, 47 | help='state epoch') 48 | parser.add_argument('--batch_size', type=int, default=1024, 49 | help='batch size') 50 | parser.add_argument('--train', type=str, default='True', 51 | help='if train model') 52 | parser.add_argument('--mixed_precision', type=str, default='False', 53 | help='if use multi-gpu') 54 | parser.add_argument('--multi_gpus', type=str, default='False', 55 | help='if use multi-gpu') 56 | parser.add_argument('--gpu_id', type=int, default=1, 57 | help='gpu id') 58 | parser.add_argument('--local_rank', default=-1, type=int, 59 | help='node rank for distributed training') 60 | parser.add_argument('--random_sample', action='store_true',default=True, 61 | help='whether to sample the dataset with random sampler') 62 | args = parser.parse_args() 63 | return args 64 | 65 | 66 | def main(args): 67 | time_stamp = get_time_stamp() 68 | stamp = '_'.join([str(args.model),'nf'+str(args.nf),str(args.stamp),str(args.CONFIG_NAME),str(args.imsize),time_stamp]) 69 | args.model_save_file = osp.join(ROOT_PATH, 'saved_models', str(args.CONFIG_NAME), stamp) 70 | log_dir = args.log_dir 71 | if log_dir == 'new': 72 | log_dir = osp.join(ROOT_PATH, 'logs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp))) 73 | args.img_save_dir = osp.join(ROOT_PATH, 'imgs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp))) 74 | if (args.multi_gpus==True) and (get_rank() != 0): 75 | None 76 | else: 77 | mkdir_p(osp.join(ROOT_PATH, 'logs')) 78 | mkdir_p(args.model_save_file) 79 | mkdir_p(args.img_save_dir) 80 | # prepare TensorBoard 81 | if (args.multi_gpus==True) and (get_rank() != 0): 82 | writer = None 83 | else: 84 | writer = SummaryWriter(log_dir) 85 | # prepare dataloader, models, data 86 | train_dl, valid_dl ,train_ds, valid_ds, sampler = prepare_dataloaders(args) 87 | CLIP4trn, CLIP4evl, image_encoder, text_encoder, netG, netD, netC = prepare_models(args) 88 | print('**************G_paras: ',params_count(netG)) 89 | print('**************D_paras: ',params_count(netD)+params_count(netC)) 90 | fixed_img, fixed_sent, fixed_words, fixed_z = get_fix_data(train_dl, valid_dl, text_encoder, args) 91 | if (args.multi_gpus==True) and (get_rank() != 0): 92 | None 93 | else: 94 | fixed_grid = make_grid(fixed_img.cpu(), nrow=8, normalize=True) 95 | 
#writer.add_image('fixed images', fixed_grid, 0) 96 | img_name = 'gt.png' 97 | img_save_path = osp.join(args.img_save_dir, img_name) 98 | vutils.save_image(fixed_img.data, img_save_path, nrow=8, normalize=True) 99 | # prepare optimizer 100 | D_params = list(netD.parameters()) + list(netC.parameters()) 101 | optimizerD = torch.optim.Adam(D_params, lr=args.lr_d, betas=(0.0, 0.9)) 102 | optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr_g, betas=(0.0, 0.9)) 103 | if args.mixed_precision==True: 104 | scaler_D = torch.cuda.amp.GradScaler(growth_interval=args.growth_interval) 105 | scaler_G = torch.cuda.amp.GradScaler(growth_interval=args.growth_interval) 106 | else: 107 | scaler_D = None 108 | scaler_G = None 109 | m1, s1 = load_npz(args.npz_path) 110 | start_epoch = 1 111 | # load from checkpoint 112 | if args.state_epoch!=1: 113 | start_epoch = args.state_epoch + 1 114 | path = osp.join(args.pretrained_model_path, 'state_epoch_%03d.pth'%(args.state_epoch)) 115 | netG, netD, netC, optimizerG, optimizerD = load_models_opt(netG, netD, netC, optimizerG, optimizerD, path, args.multi_gpus) 116 | #netG, netD, netC = load_model(netG, netD, netC, path, args.multi_gpus) 117 | # print args 118 | if (args.multi_gpus==True) and (get_rank() != 0): 119 | None 120 | else: 121 | pprint.pprint(args) 122 | arg_save_path = osp.join(log_dir, 'args.yaml') 123 | save_args(arg_save_path, args) 124 | print("Start Training") 125 | # Start training 126 | test_interval,gen_interval,save_interval = args.test_interval,args.gen_interval,args.save_interval 127 | #torch.cuda.empty_cache() 128 | # start_epoch = 1 129 | for epoch in range(start_epoch, args.max_epoch, 1): 130 | if (args.multi_gpus==True): 131 | sampler.set_epoch(epoch) 132 | start_t = time.time() 133 | # training 134 | args.current_epoch = epoch 135 | torch.cuda.empty_cache() 136 | train(train_dl, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args) 137 | torch.cuda.empty_cache() 138 | # save 139 | if epoch%save_interval==0: 140 | save_models_opt(netG, netD, netC, optimizerG, optimizerD, epoch, args.multi_gpus, args.model_save_file) 141 | torch.cuda.empty_cache() 142 | # sample 143 | if epoch%gen_interval==0: 144 | sample(fixed_z, fixed_sent, netG, args.multi_gpus, epoch, args.img_save_dir, writer) 145 | torch.cuda.empty_cache() 146 | # test 147 | if epoch%test_interval==0: 148 | #torch.cuda.empty_cache() 149 | FID, TI_score = test(valid_dl, text_encoder, netG, CLIP4evl, args.device, m1, s1, epoch, args.max_epoch, args.sample_times, args.z_dim, args.batch_size) 150 | torch.cuda.empty_cache() 151 | if (args.multi_gpus==True) and (get_rank() != 0): 152 | None 153 | else: 154 | if epoch%test_interval==0: 155 | writer.add_scalar('FID', FID, epoch) 156 | writer.add_scalar('CLIP_Score', TI_score, epoch) 157 | print('The %d epoch FID: %.2f, CLIP_Score: %.2f' % (epoch,FID,TI_score*100)) 158 | end_t = time.time() 159 | print('The epoch %d costs %.2fs'%(epoch, end_t-start_t)) 160 | print('*'*40) 161 | 162 | 163 | 164 | if __name__ == "__main__": 165 | args = merge_args_yaml(parse_args()) 166 | # set seed 167 | if args.manual_seed is None: 168 | args.manual_seed = 100 169 | #args.manualSeed = random.randint(1, 10000) 170 | random.seed(args.manual_seed) 171 | np.random.seed(args.manual_seed) 172 | torch.manual_seed(args.manual_seed) 173 | if args.cuda: 174 | if args.multi_gpus: 175 | torch.cuda.manual_seed_all(args.manual_seed) 176 | torch.distributed.init_process_group(backend="nccl") 177 | local_rank = 
torch.distributed.get_rank() 178 | torch.cuda.set_device(local_rank) 179 | args.device = torch.device("cuda", local_rank) 180 | args.local_rank = local_rank 181 | else: 182 | torch.cuda.manual_seed_all(args.manual_seed) 183 | torch.cuda.set_device(args.gpu_id) 184 | args.device = torch.device("cuda") 185 | else: 186 | args.device = torch.device('cpu') 187 | main(args) 188 | 189 | -------------------------------------------------------------------------------- /code/models/GALIP.py: -------------------------------------------------------------------------------- 1 | ####### 生成器,判别器未采用残差 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from collections import OrderedDict 7 | from lib.utils import dummy_context_mgr 8 | 9 | 10 | class CLIP_IMG_ENCODER(nn.Module): 11 | def __init__(self, CLIP): 12 | super(CLIP_IMG_ENCODER, self).__init__() 13 | model = CLIP.visual 14 | # print(model) 15 | self.define_module(model) 16 | for param in self.parameters(): 17 | param.requires_grad = False 18 | 19 | def define_module(self, model): 20 | self.conv1 = model.conv1 21 | self.class_embedding = model.class_embedding 22 | self.positional_embedding = model.positional_embedding 23 | self.ln_pre = model.ln_pre 24 | self.transformer = model.transformer 25 | self.ln_post = model.ln_post 26 | self.proj = model.proj 27 | 28 | @property 29 | def dtype(self): 30 | return self.conv1.weight.dtype 31 | 32 | def transf_to_CLIP_input(self,inputs): 33 | device = inputs.device 34 | if len(inputs.size()) != 4: 35 | raise ValueError('Expect the (B, C, X, Y) tensor.') 36 | else: 37 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])\ 38 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) 39 | var = torch.tensor([0.26862954, 0.26130258, 0.27577711])\ 40 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) 41 | inputs = F.interpolate(inputs*0.5+0.5, size=(224, 224)) 42 | inputs = ((inputs+1)*0.5-mean)/var 43 | return inputs 44 | 45 | def forward(self, img: torch.Tensor): 46 | x = self.transf_to_CLIP_input(img) 47 | x = x.type(self.dtype) 48 | x = self.conv1(x) # shape = [*, width, grid, grid] 49 | grid = x.size(-1) 50 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 51 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 52 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 53 | x = x + self.positional_embedding.to(x.dtype) 54 | x = self.ln_pre(x) 55 | # NLD -> LND 56 | x = x.permute(1, 0, 2) 57 | # Local features 58 | #selected = [1,4,7,12] 59 | selected = [1,4,8] 60 | local_features = [] 61 | for i in range(12): 62 | x = self.transformer.resblocks[i](x) 63 | if i in selected: 64 | local_features.append(x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)) 65 | x = x.permute(1, 0, 2) # LND -> NLD 66 | x = self.ln_post(x[:, 0, :]) 67 | if self.proj is not None: 68 | x = x @ self.proj 69 | return torch.stack(local_features, dim=1), x.type(img.dtype) 70 | 71 | 72 | class CLIP_TXT_ENCODER(nn.Module): 73 | def __init__(self, CLIP): 74 | super(CLIP_TXT_ENCODER, self).__init__() 75 | self.define_module(CLIP) 76 | # print(model) 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def define_module(self, CLIP): 81 | self.transformer = CLIP.transformer 82 | self.vocab_size = CLIP.vocab_size 83 | self.token_embedding = CLIP.token_embedding 84 | 
self.positional_embedding = CLIP.positional_embedding 85 | self.ln_final = CLIP.ln_final 86 | self.text_projection = CLIP.text_projection 87 | 88 | @property 89 | def dtype(self): 90 | return self.transformer.resblocks[0].mlp.c_fc.weight.dtype 91 | 92 | def forward(self, text): 93 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 94 | x = x + self.positional_embedding.type(self.dtype) 95 | x = x.permute(1, 0, 2) # NLD -> LND 96 | x = self.transformer(x) 97 | x = x.permute(1, 0, 2) # LND -> NLD 98 | x = self.ln_final(x).type(self.dtype) 99 | # x.shape = [batch_size, n_ctx, transformer.width] 100 | # take features from the eot embedding (eot_token is the highest number in each sequence) 101 | sent_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 102 | return sent_emb, x 103 | 104 | 105 | class CLIP_Mapper(nn.Module): 106 | def __init__(self, CLIP): 107 | super(CLIP_Mapper, self).__init__() 108 | model = CLIP.visual 109 | # print(model) 110 | self.define_module(model) 111 | for param in model.parameters(): 112 | param.requires_grad = False 113 | 114 | def define_module(self, model): 115 | self.conv1 = model.conv1 116 | self.class_embedding = model.class_embedding 117 | self.positional_embedding = model.positional_embedding 118 | self.ln_pre = model.ln_pre 119 | self.transformer = model.transformer 120 | 121 | @property 122 | def dtype(self): 123 | return self.conv1.weight.dtype 124 | 125 | def forward(self, img: torch.Tensor, prompts: torch.Tensor): 126 | x = img.type(self.dtype) 127 | prompts = prompts.type(self.dtype) 128 | grid = x.size(-1) 129 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 130 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 131 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) 132 | # shape = [*, grid ** 2 + 1, width] 133 | x = x + self.positional_embedding.to(x.dtype) 134 | x = self.ln_pre(x) 135 | # NLD -> LND 136 | x = x.permute(1, 0, 2) 137 | # Local features 138 | selected = [1,2,3,4,5,6,7,8] 139 | begin, end = 0, 12 140 | prompt_idx = 0 141 | for i in range(begin, end): 142 | if i in selected: 143 | prompt = prompts[:,prompt_idx,:].unsqueeze(0) 144 | prompt_idx = prompt_idx+1 145 | x = torch.cat((x,prompt), dim=0) 146 | x = self.transformer.resblocks[i](x) 147 | x = x[:-1,:,:] 148 | else: 149 | x = self.transformer.resblocks[i](x) 150 | return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype) 151 | 152 | 153 | class CLIP_Adapter(nn.Module): 154 | def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP): 155 | super(CLIP_Adapter, self).__init__() 156 | self.CLIP_ch = CLIP_ch 157 | self.FBlocks = nn.ModuleList([]) 158 | self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p)) 159 | for i in range(map_num-1): 160 | self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p)) 161 | self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2) 162 | self.CLIP_ViT = CLIP_Mapper(CLIP) 163 | self.conv = nn.Conv2d(768, G_ch, 5, 1, 2) 164 | # 165 | self.fc_prompt = nn.Linear(cond_dim, CLIP_ch*8) 166 | 167 | def forward(self,out,c): 168 | prompts = self.fc_prompt(c).view(c.size(0),-1,self.CLIP_ch) 169 | for FBlock in self.FBlocks: 170 | out = FBlock(out,c) 171 | fuse_feat = self.conv_fuse(out) 172 | map_feat = self.CLIP_ViT(fuse_feat,prompts) 173 | return self.conv(fuse_feat+0.1*map_feat) 174 | 
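# ----------------------------------------------------------------------------
# Summary of the networks defined below:
# * NetG (generator): fc_code projects the noise vector into a 7x7x64 feature
#   code; CLIP_Adapter (a stack of M_Blocks plus the frozen CLIP-ViT wrapped by
#   CLIP_Mapper) maps it to a 7x7 feature map with ngf*8 channels; the G_Blocks,
#   conditioned on the concatenated [noise, sentence embedding], upsample it
#   from 7x7 to 224x224 before to_rgb emits the 3-channel image.
# * NetD (discriminator): consumes the stack of frozen CLIP-ViT local feature
#   maps returned by CLIP_IMG_ENCODER, fusing them with two D_Blocks and a
#   final D_Block that reduces 768 to 512 channels.
# * NetC: broadcasts the sentence embedding to 7x7, concatenates it with the
#   512-channel image features, and scores the pair with two 4x4 convolutions.
# ----------------------------------------------------------------------------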
175 | 176 | class NetG(nn.Module): 177 | def __init__(self, ngf, nz, cond_dim, imsize, ch_size, mixed_precision, CLIP): 178 | super(NetG, self).__init__() 179 | self.ngf = ngf 180 | self.mixed_precision = mixed_precision 181 | # build CLIP Mapper 182 | self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32 183 | self.CLIP_ch = 768 184 | self.fc_code = nn.Linear(nz, self.code_sz*self.code_sz*self.code_ch) 185 | self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf*8, self.CLIP_ch, cond_dim+nz, 3, 1, 1, 4, CLIP) 186 | # build GBlocks 187 | self.GBlocks = nn.ModuleList([]) 188 | in_out_pairs = list(get_G_in_out_chs(ngf, imsize)) 189 | imsize = 4 190 | for idx, (in_ch, out_ch) in enumerate(in_out_pairs): 191 | if idx<(len(in_out_pairs)-1): 192 | imsize = imsize*2 193 | else: 194 | imsize = 224 195 | self.GBlocks.append(G_Block(cond_dim+nz, in_ch, out_ch, imsize)) 196 | # to RGB image 197 | self.to_rgb = nn.Sequential( 198 | nn.LeakyReLU(0.2,inplace=True), 199 | nn.Conv2d(out_ch, ch_size, 3, 1, 1), 200 | #nn.Tanh(), 201 | ) 202 | 203 | def forward(self, noise, c, eval=False): # x=noise, c=sent_emb 204 | with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp: 205 | cond = torch.cat((noise, c), dim=1) 206 | out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond) 207 | # fuse text and visual features 208 | for GBlock in self.GBlocks: 209 | out = GBlock(out, cond) 210 | # convert to RGB image 211 | out = self.to_rgb(out) 212 | return out 213 | 214 | 215 | # Define the discriminator network D 216 | class NetD(nn.Module): 217 | def __init__(self, ndf, imsize, ch_size, mixed_precision): 218 | super(NetD, self).__init__() 219 | self.mixed_precision = mixed_precision 220 | self.DBlocks = nn.ModuleList([ 221 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True), 222 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True), 223 | ]) 224 | self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False) 225 | 226 | def forward(self, h): 227 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc: 228 | out = h[:,0] 229 | for idx in range(len(self.DBlocks)): 230 | out = self.DBlocks[idx](out, h[:,idx+1]) 231 | out = self.main(out) 232 | return out 233 | 234 | 235 | class NetC(nn.Module): 236 | def __init__(self, ndf, cond_dim, mixed_precision): 237 | super(NetC, self).__init__() 238 | self.cond_dim = cond_dim 239 | self.mixed_precision = mixed_precision 240 | self.joint_conv = nn.Sequential( 241 | nn.Conv2d(512+512, 128, 4, 1, 0, bias=False), 242 | nn.LeakyReLU(0.2, inplace=True), 243 | nn.Conv2d(128, 1, 4, 1, 0, bias=False), 244 | ) 245 | 246 | def forward(self, out, cond): 247 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc: 248 | cond = cond.view(-1, self.cond_dim, 1, 1) 249 | cond = cond.repeat(1, 1, 7, 7) 250 | h_c_code = torch.cat((out, cond), 1) 251 | out = self.joint_conv(h_c_code) 252 | return out 253 | 254 | 255 | class M_Block(nn.Module): 256 | def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p): 257 | super(M_Block, self).__init__() 258 | self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p) 259 | self.fuse1 = DFBLK(cond_dim, mid_ch) 260 | self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p) 261 | self.fuse2 = DFBLK(cond_dim, out_ch) 262 | self.learnable_sc = in_ch != out_ch 263 | if self.learnable_sc: 264 | self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 265 | 266 | def shortcut(self, x): 267 | if self.learnable_sc: 268 | x = 
self.c_sc(x) 269 | return x 270 | 271 | def residual(self, h, text): 272 | h = self.conv1(h) 273 | h = self.fuse1(h, text) 274 | h = self.conv2(h) 275 | h = self.fuse2(h, text) 276 | return h 277 | 278 | def forward(self, h, c): 279 | return self.shortcut(h) + self.residual(h, c) 280 | 281 | 282 | class G_Block(nn.Module): 283 | def __init__(self, cond_dim, in_ch, out_ch, imsize): 284 | super(G_Block, self).__init__() 285 | self.imsize = imsize 286 | self.learnable_sc = in_ch != out_ch 287 | self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1) 288 | self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1) 289 | self.fuse1 = DFBLK(cond_dim, in_ch) 290 | self.fuse2 = DFBLK(cond_dim, out_ch) 291 | if self.learnable_sc: 292 | self.c_sc = nn.Conv2d(in_ch,out_ch, 1, stride=1, padding=0) 293 | 294 | def shortcut(self, x): 295 | if self.learnable_sc: 296 | x = self.c_sc(x) 297 | return x 298 | 299 | def residual(self, h, y): 300 | h = self.fuse1(h, y) 301 | h = self.c1(h) 302 | h = self.fuse2(h, y) 303 | h = self.c2(h) 304 | return h 305 | 306 | def forward(self, h, y): 307 | h = F.interpolate(h, size=(self.imsize, self.imsize)) 308 | return self.shortcut(h) + self.residual(h, y) 309 | 310 | 311 | class D_Block(nn.Module): 312 | def __init__(self, fin, fout, k, s, p, res, CLIP_feat): 313 | super(D_Block, self).__init__() 314 | self.res, self.CLIP_feat = res, CLIP_feat 315 | self.learned_shortcut = (fin != fout) 316 | self.conv_r = nn.Sequential( 317 | nn.Conv2d(fin, fout, k, s, p, bias=False), 318 | nn.LeakyReLU(0.2, inplace=True), 319 | nn.Conv2d(fout, fout, k, s, p, bias=False), 320 | nn.LeakyReLU(0.2, inplace=True), 321 | ) 322 | self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0) 323 | if self.res==True: 324 | self.gamma = nn.Parameter(torch.zeros(1)) 325 | if self.CLIP_feat==True: 326 | self.beta = nn.Parameter(torch.zeros(1)) 327 | 328 | def forward(self, x, CLIP_feat=None): 329 | res = self.conv_r(x) 330 | if self.learned_shortcut: 331 | x = self.conv_s(x) 332 | if (self.res==True)and(self.CLIP_feat==True): 333 | return x + self.gamma*res + self.beta*CLIP_feat 334 | elif (self.res==True)and(self.CLIP_feat!=True): 335 | return x + self.gamma*res 336 | elif (self.res!=True)and(self.CLIP_feat==True): 337 | return x + self.beta*CLIP_feat 338 | else: 339 | return x 340 | 341 | 342 | class DFBLK(nn.Module): 343 | def __init__(self, cond_dim, in_ch): 344 | super(DFBLK, self).__init__() 345 | self.affine0 = Affine(cond_dim, in_ch) 346 | self.affine1 = Affine(cond_dim, in_ch) 347 | 348 | def forward(self, x, y=None): 349 | h = self.affine0(x, y) 350 | h = nn.LeakyReLU(0.2,inplace=True)(h) 351 | h = self.affine1(h, y) 352 | h = nn.LeakyReLU(0.2,inplace=True)(h) 353 | return h 354 | 355 | 356 | class QuickGELU(nn.Module): 357 | def forward(self, x: torch.Tensor): 358 | return x * torch.sigmoid(1.702 * x) 359 | 360 | 361 | class Affine(nn.Module): 362 | def __init__(self, cond_dim, num_features): 363 | super(Affine, self).__init__() 364 | 365 | self.fc_gamma = nn.Sequential(OrderedDict([ 366 | ('linear1',nn.Linear(cond_dim, num_features)), 367 | ('relu1',nn.ReLU(inplace=True)), 368 | ('linear2',nn.Linear(num_features, num_features)), 369 | ])) 370 | self.fc_beta = nn.Sequential(OrderedDict([ 371 | ('linear1',nn.Linear(cond_dim, num_features)), 372 | ('relu1',nn.ReLU(inplace=True)), 373 | ('linear2',nn.Linear(num_features, num_features)), 374 | ])) 375 | self._initialize() 376 | 377 | def _initialize(self): 378 | nn.init.zeros_(self.fc_gamma.linear2.weight.data) 379 | nn.init.ones_(self.fc_gamma.linear2.bias.data) 
380 | nn.init.zeros_(self.fc_beta.linear2.weight.data) 381 | nn.init.zeros_(self.fc_beta.linear2.bias.data) 382 | 383 | def forward(self, x, y=None): 384 | weight = self.fc_gamma(y) 385 | bias = self.fc_beta(y) 386 | 387 | if weight.dim() == 1: 388 | weight = weight.unsqueeze(0) 389 | if bias.dim() == 1: 390 | bias = bias.unsqueeze(0) 391 | 392 | size = x.size() 393 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 394 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 395 | return weight * x + bias 396 | 397 | 398 | def get_G_in_out_chs(nf, imsize): 399 | layer_num = int(np.log2(imsize))-1 400 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)] 401 | channel_nums = channel_nums[::-1] 402 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:]) 403 | return in_out_pairs 404 | 405 | 406 | def get_D_in_out_chs(nf, imsize): 407 | layer_num = int(np.log2(imsize))-1 408 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)] 409 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:]) 410 | return in_out_pairs 411 | -------------------------------------------------------------------------------- /code/lib/modules.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from pyexpat import features 3 | import os.path as osp 4 | import time 5 | import random 6 | import datetime 7 | import argparse 8 | from scipy import linalg 9 | import numpy as np 10 | from PIL import Image 11 | from tqdm import tqdm, trange 12 | from torch.autograd import Variable 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | import torch.backends.cudnn as cudnn 18 | import torchvision.transforms as transforms 19 | import torchvision.utils as vutils 20 | from torchvision.utils import make_grid 21 | from lib.utils import transf_to_CLIP_input, dummy_context_mgr 22 | from lib.utils import mkdir_p, get_rank 23 | from lib.datasets import prepare_data 24 | 25 | from models.inception import InceptionV3 26 | from torch.nn.functional import adaptive_avg_pool2d 27 | import torch.distributed as dist 28 | 29 | 30 | ############ GAN ############ 31 | def train(dataloader, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD, scaler_G, scaler_D, args): 32 | batch_size = args.batch_size 33 | device = args.device 34 | epoch = args.current_epoch 35 | max_epoch = args.max_epoch 36 | z_dim = args.z_dim 37 | netG, netD, netC, image_encoder = netG.train(), netD.train(), netC.train(), image_encoder.train() 38 | if (args.multi_gpus==True) and (get_rank() != 0): 39 | None 40 | else: 41 | loop = tqdm(total=len(dataloader)) 42 | for step, data in enumerate(dataloader, 0): 43 | ############## 44 | # Train D 45 | ############## 46 | optimizerD.zero_grad() 47 | with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc: 48 | # prepare_data 49 | real, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device) 50 | real = real.requires_grad_() 51 | sent_emb = sent_emb.requires_grad_() 52 | words_embs = words_embs.requires_grad_() 53 | # predict real 54 | CLIP_real,real_emb = image_encoder(real) 55 | real_feats = netD(CLIP_real) 56 | pred_real, errD_real = predict_loss(netC, real_feats, sent_emb, negtive=False) 57 | # predict mismatch 58 | mis_sent_emb = torch.cat((sent_emb[1:], sent_emb[0:1]), dim=0).detach() 59 | _, errD_mis = predict_loss(netC, real_feats, mis_sent_emb, negtive=True) 60 | # synthesize fake images 61 | 
noise = torch.randn(batch_size, z_dim).to(device) 62 | fake = netG(noise, sent_emb) 63 | CLIP_fake, fake_emb = image_encoder(fake) 64 | fake_feats = netD(CLIP_fake.detach()) 65 | _, errD_fake = predict_loss(netC, fake_feats, sent_emb, negtive=True) 66 | # MA-GP 67 | if args.mixed_precision: 68 | errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, scaler_D) 69 | else: 70 | errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real) 71 | # whole D loss 72 | with torch.cuda.amp.autocast() if args.mixed_precision else dummy_context_mgr() as mpc: 73 | errD = errD_real + (errD_fake + errD_mis)/2.0 + errD_MAGP 74 | # update D 75 | if args.mixed_precision: 76 | scaler_D.scale(errD).backward() 77 | scaler_D.step(optimizerD) 78 | scaler_D.update() 79 | if scaler_D.get_scale() [0, 255] 204 | im = (im + 1.0) * 127.5 205 | im = im.astype(np.uint8) 206 | im = np.transpose(im, (1, 2, 0)) 207 | im = Image.fromarray(im) 208 | ###################################################### 209 | # (3) Save fake images 210 | ###################################################### 211 | if multi_gpus==True: 212 | single_img_name = 'batch_%04d.png'%(j) 213 | single_img_save_dir = osp.join(save_dir, 'single', str('gpu%d'%(get_rank())), 'step%04d'%(step)) 214 | single_img_save_name = osp.join(single_img_save_dir, single_img_name) 215 | else: 216 | single_img_name = 'step_%04d.png'%(step) 217 | single_img_save_dir = osp.join(save_dir, 'single', 'step%04d'%(step)) 218 | single_img_save_name = osp.join(single_img_save_dir, single_img_name) 219 | mkdir_p(single_img_save_dir) 220 | im.save(single_img_save_name) 221 | if (multi_gpus==True) and (get_rank() != 0): 222 | None 223 | else: 224 | print('Step: %d' % (step)) 225 | 226 | 227 | def calculate_FID_CLIP_sim(dataloader, text_encoder, netG, CLIP, device, m1, s1, epoch, max_epoch, times, z_dim, batch_size): 228 | """ Calculates the FID """ 229 | clip_cos = torch.FloatTensor([0.0]).to(device) 230 | # prepare Inception V3 231 | dims = 2048 232 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 233 | model = InceptionV3([block_idx]) 234 | model.to(device) 235 | model.eval() 236 | netG.eval() 237 | norm = transforms.Compose([ 238 | transforms.Normalize((-1, -1, -1), (2, 2, 2)), 239 | transforms.Resize((299, 299)), 240 | ]) 241 | n_gpu = dist.get_world_size() 242 | dl_length = dataloader.__len__() 243 | imgs_num = dl_length * n_gpu * batch_size * times 244 | pred_arr = np.empty((imgs_num, dims)) 245 | if (n_gpu!=1) and (get_rank() != 0): 246 | None 247 | else: 248 | loop = tqdm(total=int(dl_length*times)) 249 | for time in range(times): 250 | for i, data in enumerate(dataloader): 251 | start = i * batch_size * n_gpu + time * dl_length * n_gpu * batch_size 252 | end = start + batch_size * n_gpu 253 | ###################################################### 254 | # (1) Prepare_data 255 | ###################################################### 256 | imgs, captions, CLIP_tokens, sent_emb, words_embs, keys = prepare_data(data, text_encoder, device) 257 | ###################################################### 258 | # (2) Generate fake images 259 | ###################################################### 260 | batch_size = sent_emb.size(0) 261 | netG.eval() 262 | with torch.no_grad(): 263 | noise = torch.randn(batch_size, z_dim).to(device) 264 | fake_imgs = netG(noise,sent_emb,eval=True).float() 265 | # norm_ip(fake_imgs, -1, 1) 266 | fake_imgs = torch.clamp(fake_imgs, -1., 1.) 
267 | fake_imgs = torch.nan_to_num(fake_imgs, nan=-1.0, posinf=1.0, neginf=-1.0) 268 | clip_sim = calc_clip_sim(CLIP, fake_imgs, CLIP_tokens, device) 269 | clip_cos = clip_cos + clip_sim 270 | fake = norm(fake_imgs) 271 | pred = model(fake)[0] 272 | if pred.shape[2] != 1 or pred.shape[3] != 1: 273 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 274 | # concat pred from multi GPUs 275 | output = list(torch.empty_like(pred) for _ in range(n_gpu)) 276 | dist.all_gather(output, pred) 277 | pred_all = torch.cat(output, dim=0).squeeze(-1).squeeze(-1) 278 | pred_arr[start:end] = pred_all.cpu().data.numpy() 279 | # update loop information 280 | if (n_gpu!=1) and (get_rank() != 0): 281 | None 282 | else: 283 | loop.update(1) 284 | if epoch==-1: 285 | loop.set_description('Evaluating]') 286 | else: 287 | loop.set_description(f'Eval Epoch [{epoch}/{max_epoch}]') 288 | loop.set_postfix() 289 | if (n_gpu!=1) and (get_rank() != 0): 290 | None 291 | else: 292 | loop.close() 293 | # CLIP-score 294 | CLIP_score_gather = list(torch.empty_like(clip_cos) for _ in range(n_gpu)) 295 | dist.all_gather(CLIP_score_gather, clip_cos) 296 | clip_score = torch.cat(CLIP_score_gather, dim=0).mean().item()/(dl_length*times) 297 | # FID 298 | m2 = np.mean(pred_arr, axis=0) 299 | s2 = np.cov(pred_arr, rowvar=False) 300 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 301 | return fid_value,clip_score 302 | 303 | 304 | def calc_clip_sim(clip, fake, caps_clip, device): 305 | ''' calculate cosine similarity between fake and text features, 306 | ''' 307 | # Calculate features 308 | fake = transf_to_CLIP_input(fake) 309 | fake_features = clip.encode_image(fake) 310 | text_features = clip.encode_text(caps_clip) 311 | text_img_sim = torch.cosine_similarity(fake_features, text_features).mean() 312 | return text_img_sim 313 | 314 | 315 | def sample_one_batch(noise, sent, netG, multi_gpus, epoch, img_save_dir, writer): 316 | if (multi_gpus==True) and (get_rank() != 0): 317 | None 318 | else: 319 | netG.eval() 320 | with torch.no_grad(): 321 | B = noise.size(0) 322 | fixed_results_train = generate_samples(noise[:B//2], sent[:B//2], netG).cpu() 323 | torch.cuda.empty_cache() 324 | fixed_results_test = generate_samples(noise[B//2:], sent[B//2:], netG).cpu() 325 | torch.cuda.empty_cache() 326 | fixed_results = torch.cat((fixed_results_train, fixed_results_test), dim=0) 327 | img_name = 'samples_epoch_%03d.png'%(epoch) 328 | img_save_path = osp.join(img_save_dir, img_name) 329 | vutils.save_image(fixed_results.data, img_save_path, nrow=8, value_range=(-1, 1), normalize=True) 330 | 331 | 332 | def generate_samples(noise, caption, model): 333 | with torch.no_grad(): 334 | fake = model(noise, caption, eval=True) 335 | return fake 336 | 337 | 338 | def predict_loss(predictor, img_feature, text_feature, negtive): 339 | output = predictor(img_feature, text_feature) 340 | err = hinge_loss(output, negtive) 341 | return output,err 342 | 343 | 344 | def hinge_loss(output, negtive): 345 | if negtive==False: 346 | err = torch.mean(F.relu(1. - output)) 347 | else: 348 | err = torch.mean(F.relu(1. 
+ output)) 349 | return err 350 | 351 | 352 | def logit_loss(output, negtive): 353 | batch_size = output.size(0) 354 | real_labels = torch.FloatTensor(batch_size,1).fill_(1).to(output.device) 355 | fake_labels = torch.FloatTensor(batch_size,1).fill_(0).to(output.device) 356 | output = nn.Sigmoid()(output) 357 | if negtive==False: 358 | err = nn.BCELoss()(output, real_labels) 359 | else: 360 | err = nn.BCELoss()(output, fake_labels) 361 | return err 362 | 363 | 364 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 365 | mu1 = np.atleast_1d(mu1) 366 | mu2 = np.atleast_1d(mu2) 367 | 368 | sigma1 = np.atleast_2d(sigma1) 369 | sigma2 = np.atleast_2d(sigma2) 370 | 371 | assert mu1.shape == mu2.shape, \ 372 | 'Training and test mean vectors have different lengths' 373 | assert sigma1.shape == sigma2.shape, \ 374 | 'Training and test covariances have different dimensions' 375 | 376 | diff = mu1 - mu2 377 | ''' 378 | print('&'*20) 379 | print(sigma1)#, sigma1.type()) 380 | print('&'*20) 381 | print(sigma2)#, sigma2.type()) 382 | ''' 383 | # Product might be almost singular 384 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 385 | if not np.isfinite(covmean).all(): 386 | msg = ('fid calculation produces singular product; ' 387 | 'adding %s to diagonal of cov estimates') % eps 388 | print(msg) 389 | offset = np.eye(sigma1.shape[0]) * eps 390 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 391 | 392 | # Numerical error might give slight imaginary component 393 | if np.iscomplexobj(covmean): 394 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 395 | m = np.max(np.abs(covmean.imag)) 396 | raise ValueError('Imaginary component {}'.format(m)) 397 | covmean = covmean.real 398 | 399 | tr_covmean = np.trace(covmean) 400 | 401 | return (diff.dot(diff) + np.trace(sigma1) + 402 | np.trace(sigma2) - 2 * tr_covmean) 403 | --------------------------------------------------------------------------------
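A minimal text-to-image inference sketch, assuming the OpenAI clip package, the class definitions from code/models/GALIP.py above, the ViT-B/32 backbone and hyper-parameters from code/cfg/birds.yml (nf=64, z_dim=100, cond_dim=512, imsize=256, ch_size=3), and that the code/ directory is on sys.path; the checkpoint loading step is a placeholder, since the saved-model layout is not shown in this dump.

import torch
import clip  # OpenAI CLIP; assumed installed alongside the listed ftfy/regex dependencies
from models.GALIP import NetG, CLIP_TXT_ENCODER  # assumes code/ is on sys.path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the frozen CLIP backbone (ViT-B/32, as in the cfg files) in fp32 so its
# text features match the fp32 generator inputs.
clip_model, _ = clip.load('ViT-B/32', device=device)
clip_model = clip_model.float().eval()

text_encoder = CLIP_TXT_ENCODER(clip_model).to(device).eval()
netG = NetG(ngf=64, nz=100, cond_dim=512, imsize=256, ch_size=3,
            mixed_precision=False, CLIP=clip_model).to(device).eval()
# netG.load_state_dict(...)  # placeholder: load a trained checkpoint here

tokens = clip.tokenize(['this bird has a red head and a white belly']).to(device)
with torch.no_grad():
    sent_emb, word_embs = text_encoder(tokens)   # sentence embedding: (1, 512)
    noise = torch.randn(1, 100, device=device)   # z_dim = 100
    fake = netG(noise, sent_emb, eval=True)      # (1, 3, 224, 224), roughly in [-1, 1]

The output is not passed through Tanh, which is why the evaluation code above clamps generated images to [-1, 1] before computing FID and CLIP similarity.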