├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── .DS_Store └── figs │ ├── .DS_Store │ ├── advanced_fig.png │ ├── basic_fig.png │ ├── gan_fig.png │ ├── handson_fig.png │ └── teaser.png ├── lecture ├── 1_Basic_diffusion.pdf ├── 2_Advanced_diffusion.pdf └── 3_HandsOn_diffusion_noans.pdf └── src ├── .DS_Store ├── GALIP ├── .DS_Store ├── GALIP.py ├── dataset │ ├── .DS_Store │ └── coco_2017 │ │ ├── .DS_Store │ │ ├── train │ │ ├── .DS_Store │ │ ├── image │ │ │ ├── 000000000009.jpg │ │ │ ├── 000000000025.jpg │ │ │ ├── 000000000030.jpg │ │ │ ├── 000000000034.jpg │ │ │ ├── 000000000036.jpg │ │ │ └── 000000000042.jpg │ │ └── text │ │ │ ├── 000000000009.txt │ │ │ ├── 000000000025.txt │ │ │ ├── 000000000030.txt │ │ │ ├── 000000000034.txt │ │ │ ├── 000000000036.txt │ │ │ └── 000000000042.txt │ │ └── val │ │ ├── .DS_Store │ │ ├── image │ │ ├── 000000000139.jpg │ │ ├── 000000000285.jpg │ │ ├── 000000000632.jpg │ │ ├── 000000000724.jpg │ │ ├── 000000000776.jpg │ │ └── 000000000785.jpg │ │ └── text │ │ ├── 000000000139.txt │ │ ├── 000000000285.txt │ │ ├── 000000000632.txt │ │ ├── 000000000724.txt │ │ ├── 000000000776.txt │ │ └── 000000000785.txt ├── main.py ├── metric │ └── fid_score.py ├── networks.py ├── ops.py └── utils.py ├── ddpm_ddim ├── .DS_Store ├── dataset │ ├── .DS_Store │ └── cat │ │ ├── flickr_cat_000008.png │ │ ├── flickr_cat_000011.png │ │ ├── flickr_cat_000016.png │ │ ├── flickr_cat_000056.png │ │ └── flickr_cat_000076.png ├── main.py ├── main_template.py ├── modules.py ├── noise.jpg ├── noise_test.py └── utils.py ├── evaluation ├── clipscore.py ├── data_loader.py └── fid.py └── stable_diffusion ├── sd_main.py └── sd_simple_main.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # diffusion-pytorch 2 | #### 이화여대 강의자료입니다. 사용시 citation 부탁드립니다. :) 3 | #### Teaching materials from Ewha Womans University. Please cite the link when used. :) 4 | 5 | 6 |
7 | ![teaser](./assets/figs/teaser.png)
9 | 10 | ## YouTube (Korean) 11 | * [The recipe of GANs](https://www.youtube.com/watch?v=vZdEGcLU_8U) 12 | * [The basic diffusion](https://www.youtube.com/watch?v=jaPPALsUZo8) 13 | * [The advanced diffusion](https://www.youtube.com/watch?v=Z8WWriIh1PU) 14 | 15 | ## Author 16 | [Junho Kim](http://bit.ly/jhkim_resume) 17 | 18 | --- 19 | ## Summary of GANs 20 |
21 | ![Summary of GANs](./assets/figs/gan_fig.png)
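
As a quick refresher to go with the figure above, here is a minimal sketch (not part of the lecture code) of the non-saturating GAN objective; `D`, `real_imgs`, and `fake_imgs` are placeholders for any discriminator module and image batches.

```python
import torch.nn.functional as F

def d_loss(D, real_imgs, fake_imgs):
    # discriminator: push logits of real images up and of generated images down
    return (F.softplus(-D(real_imgs)) + F.softplus(D(fake_imgs.detach()))).mean()

def g_loss(D, fake_imgs):
    # generator: maximize the discriminator's belief that its samples are real
    return F.softplus(-D(fake_imgs)).mean()
```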
23 | 24 | --- 25 | 26 | ## Basic diffusion (Theory) 27 | * DDPM, DDIM 28 | * Classifier guidance 29 | * Diffusion + GAN (DDGAN) 30 | 31 |
32 | ![Basic diffusion](./assets/figs/basic_fig.png)
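
To make the DDPM bullet above concrete, here is a minimal sketch, assuming a linear beta schedule and an arbitrary noise-prediction network `model(x_t, t)`, of the closed-form forward (noising) step and the epsilon-prediction training loss; it is not taken from `src/ddpm_ddim`.

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # linear beta schedule
alphas_bar = torch.cumprod(1.0 - betas, dim=0)   # cumulative product of (1 - beta_t)

def q_sample(x0, t, noise):
    # forward (noising) process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps
    ab = alphas_bar.to(x0.device)[t].view(-1, 1, 1, 1)
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise

def ddpm_loss(model, x0):
    # pick a random timestep per image and regress the injected noise
    t = torch.randint(0, T, (x0.size(0),), device=x0.device)
    noise = torch.randn_like(x0)
    return F.mse_loss(model(q_sample(x0, t, noise), t), noise)
```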
34 | 35 | --- 36 | ## Advanced diffusion (Theory) 37 | * Stable diffusion, GALIP 38 | * Evaluation 39 | * Editing 40 | 41 |
42 | ![Advanced diffusion](./assets/figs/advanced_fig.png)
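
For the evaluation topic above, text-to-image metrics such as CLIPScore start from the CLIP image-text cosine similarity; a minimal sketch follows (the image path and caption are placeholders, and `src/evaluation/clipscore.py` may differ in detail).

```python
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("sample.png")).unsqueeze(0).to(device)
tokens = clip.tokenize(["a giraffe standing next to a tree"]).to(device)

with torch.no_grad():
    img_emb = model.encode_image(image)
    txt_emb = model.encode_text(tokens)

similarity = torch.cosine_similarity(img_emb, txt_emb).clamp(min=0).item()
print(f"CLIP image-text similarity: {similarity:.4f}")
```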
44 | 45 | --- 46 | ## Hands-on diffusion (Implementation) 47 | * DDPM, DDIM 48 | * How to use Stable Diffusion (SD)? (see the sketch below) 49 | * How to evaluate? 50 | 51 |
52 | ![Hands-on diffusion](./assets/figs/handson_fig.png)
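
For "How to use Stable Diffusion (SD)?", a minimal sketch with the `diffusers` pipeline is shown below; the model id and prompt are placeholders, and `src/stable_diffusion/sd_simple_main.py` may be organized differently.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe("a cat wearing sunglasses",
             num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("sd_sample.png")
```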
54 | 55 | --- 56 | ### Recommended code 57 | * [pytorch & tensorflow code template](https://github.com/taki0112/tf-torch-template) 58 | * [Stylegan2-pytorch](https://github.com/taki0112/stylegan2-pytorch) 59 | * [GALIP-pytorch](https://github.com/taki0112/diffusion-pytorch/tree/main/src/GALIP) 60 | * [DDGAN-tensorflow](https://github.com/taki0112/denoising-diffusion-gan-Tensorflow) 61 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/.DS_Store -------------------------------------------------------------------------------- /assets/figs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/.DS_Store -------------------------------------------------------------------------------- /assets/figs/advanced_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/advanced_fig.png -------------------------------------------------------------------------------- /assets/figs/basic_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/basic_fig.png -------------------------------------------------------------------------------- /assets/figs/gan_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/gan_fig.png -------------------------------------------------------------------------------- /assets/figs/handson_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/handson_fig.png -------------------------------------------------------------------------------- /assets/figs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/teaser.png -------------------------------------------------------------------------------- /lecture/1_Basic_diffusion.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/lecture/1_Basic_diffusion.pdf -------------------------------------------------------------------------------- /lecture/2_Advanced_diffusion.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/lecture/2_Advanced_diffusion.pdf -------------------------------------------------------------------------------- /lecture/3_HandsOn_diffusion_noans.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/lecture/3_HandsOn_diffusion_noans.pdf -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/.DS_Store -------------------------------------------------------------------------------- /src/GALIP/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/.DS_Store -------------------------------------------------------------------------------- /src/GALIP/GALIP.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from ops import * 3 | from utils import * 4 | import time 5 | from torch.utils.tensorboard import SummaryWriter 6 | import numpy as np 7 | import torchvision 8 | from functools import partial 9 | import torch.nn.functional as F 10 | 11 | print = partial(print, flush=True) 12 | 13 | from metric.fid_score import InceptionV3, calculate_fid_t2i 14 | from networks import * 15 | 16 | def run_fn(rank, args, world_size): 17 | device = torch.device('cuda', rank) 18 | torch.backends.cudnn.benchmark = True 19 | 20 | model = GALIP(args, world_size) 21 | model.build_model(rank, device) 22 | model.train_model(rank, device) 23 | 24 | class GALIP(): 25 | def __init__(self, args, NUM_GPUS): 26 | super(GALIP, self).__init__() 27 | 28 | """ Model """ 29 | self.model_name = 'GALIP' 30 | self.phase = args['phase'] 31 | self.NUM_GPUS = NUM_GPUS 32 | 33 | 34 | """ Training parameters """ 35 | self.img_size = args['img_size'] 36 | self.batch_size = args['batch_size'] 37 | self.global_batch_size = self.batch_size * self.NUM_GPUS 38 | self.epoch = args['epoch'] 39 | if self.epoch != 0: 40 | self.iteration = None 41 | else: 42 | self.iteration = args['iteration'] 43 | self.mixed_flag = args['mixed_flag'] 44 | self.growth_interval = 2000 45 | self.scaler_min = 64 46 | 47 | """ Network parameters """ 48 | self.style_dim = 100 49 | self.g_lr = args['g_lr'] 50 | self.d_lr = args['d_lr'] 51 | 52 | """ Print parameters """ 53 | self.print_freq = args['print_freq'] 54 | self.save_freq = args['save_freq'] 55 | self.log_template = 'step [{}/{}]: elapsed: {:.2f}s, BEST_FID: {:.2f}' 56 | 57 | """ Dataset Path """ 58 | self.dataset_name = args['dataset'] 59 | self.val_dataset_name = self.dataset_name + '_val' 60 | dataset_path = './dataset' 61 | self.dataset_path = os.path.join(dataset_path, self.dataset_name) 62 | self.val_dataset_path = os.path.join(dataset_path, self.val_dataset_name) 63 | 64 | """ Directory """ 65 | self.checkpoint_dir = args['checkpoint_dir'] 66 | self.result_dir = args['result_dir'] 67 | self.log_dir = args['log_dir'] 68 | self.sample_dir = args['sample_dir'] 69 | 70 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir) 71 | check_folder(self.sample_dir) 72 | self.checkpoint_dir = os.path.join(self.checkpoint_dir, self.model_dir) 73 | check_folder(self.checkpoint_dir) 74 | self.log_dir = os.path.join(self.log_dir, self.model_dir) 75 | check_folder(self.log_dir) 76 | 77 | ################################################################################## 78 | # Model 79 | ################################################################################## 80 | 
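    # Overview of build_model below (run once per GPU rank): initialize the distributed
    # process group, build the train loader and the validation loader used for FID,
    # load the frozen CLIP image/text encoders and InceptionV3, create NetG / NetD / NetC
    # with their Adam optimizers (plus GradScalers when mixed precision is enabled),
    # wrap the trainable networks for distributed training, and restore the latest
    # checkpoint if one exists.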
def build_model(self, rank, device): 81 | """ Init process """ 82 | build_init_procss(rank, world_size=self.NUM_GPUS, device=device) 83 | 84 | """ Dataset Load """ 85 | dataset = ImageTextDataset(dataset_path=self.dataset_path, img_size=self.img_size) 86 | self.dataset_num = dataset.__len__() 87 | self.iteration = self.epoch * self.dataset_num // self.global_batch_size 88 | loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=4, 89 | sampler=distributed_sampler(dataset, rank=rank, num_replicas=self.NUM_GPUS, shuffle=True), 90 | drop_last=True, pin_memory=True) 91 | self.dataset_iter = infinite_iterator(loader) 92 | 93 | """ For FID """ 94 | self.fid_dataset = ImageTextDataset(dataset_path=self.val_dataset_path, img_size=299, imagenet_normalization=True) 95 | self.fid_loader = torch.utils.data.DataLoader(self.fid_dataset, batch_size=5, num_workers=4, 96 | sampler=distributed_sampler(self.fid_dataset, rank=rank, num_replicas=self.NUM_GPUS, shuffle=False), 97 | drop_last=False, pin_memory=True) 98 | self.inception = InceptionV3(mixed_precision=self.mixed_flag).to(device) 99 | requires_grad(self.inception, False) 100 | 101 | """ Pretrain Model Load """ 102 | self.clip = clip.load('ViT-B/32')[0].eval().to(device) 103 | self.clip_img = CLIP_IMG_ENCODER(self.clip).to(device) 104 | self.clip_text = CLIP_TXT_ENCODER(self.clip).to(device) 105 | 106 | requires_grad(self.clip_img, False) 107 | requires_grad(self.clip_text, False) 108 | 109 | self.clip_img.eval() 110 | self.clip_text.eval() 111 | 112 | 113 | """ Network """ 114 | if self.mixed_flag: 115 | self.scaler_G = torch.cuda.amp.GradScaler(growth_interval=self.growth_interval) 116 | self.scaler_D = torch.cuda.amp.GradScaler(growth_interval=self.growth_interval) 117 | else: 118 | self.scaler_G = None 119 | self.scaler_D = None 120 | self.generator = NetG(imsize=self.img_size, CLIP=self.clip, nz=self.style_dim, mixed_precision=self.mixed_flag).to(device) 121 | self.discriminator = NetD(imsize=self.img_size, mixed_precision=self.mixed_flag).to(device) 122 | self.predictor = NetC(mixed_precision=self.mixed_flag).to(device) 123 | 124 | 125 | """ Optimizer """ 126 | D_params = list(self.discriminator.parameters()) + list(self.predictor.parameters()) 127 | self.g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.g_lr, betas=(0.0, 0.9), eps=1e-08) 128 | self.d_optimizer = torch.optim.Adam(D_params, lr=self.d_lr, betas=(0.0, 0.9), eps=1e-08) 129 | 130 | """ Distributed Learning """ 131 | self.generator = dataparallel_and_sync(self.generator, rank) 132 | self.discriminator = dataparallel_and_sync(self.discriminator, rank) 133 | self.predictor = dataparallel_and_sync(self.predictor, rank) 134 | 135 | """ Checkpoint """ 136 | self.ckpt_dict = { 137 | 'generator': self.generator.state_dict(), 138 | 'discriminator': self.discriminator.state_dict(), 139 | 'predictor' : self.predictor.state_dict(), 140 | 'g_optimizer': self.g_optimizer.state_dict(), 141 | 'd_optimizer': self.d_optimizer.state_dict() 142 | } 143 | 144 | latest_ckpt_name, start_iter = find_latest_ckpt(self.checkpoint_dir) 145 | if latest_ckpt_name is not None: 146 | if rank == 0: 147 | print('Latest checkpoint restored!! 
', latest_ckpt_name) 148 | print('start iteration : ', start_iter) 149 | self.start_iteration = start_iter 150 | 151 | latest_ckpt = os.path.join(self.checkpoint_dir, latest_ckpt_name) 152 | ckpt = torch.load(latest_ckpt, map_location=device) 153 | 154 | self.generator.load_state_dict(ckpt["generator"]) 155 | self.discriminator.load_state_dict(ckpt['discriminator']) 156 | self.predictor.load_state_dict(ckpt['predictor']) 157 | self.g_optimizer.load_state_dict(ckpt["g_optimizer"]) 158 | self.d_optimizer.load_state_dict(ckpt["d_optimizer"]) 159 | 160 | else: 161 | if rank == 0: 162 | print('Not restoring from saved checkpoint') 163 | self.start_iteration = 0 164 | 165 | def g_train_step(self, real_img, tokens, device=torch.device('cuda')): 166 | self.generator.train() 167 | self.discriminator.train() 168 | self.predictor.train() 169 | 170 | # step 0: pre-process 171 | with torch.cuda.amp.autocast() if self.mixed_flag else dummy_context_mgr() as mpc: 172 | with torch.no_grad(): 173 | sent_emb, word_emb = self.clip_text(tokens) # [bs, 512], [bs, 77, 512] 174 | word_emb = word_emb.detach() 175 | sent_emb = sent_emb.detach() 176 | 177 | # synthesize fake images 178 | noise = torch.randn([self.batch_size, self.style_dim]).to(device) 179 | fake_img = self.generator(noise, sent_emb) 180 | CLIP_fake, fake_emb = self.clip_img(fake_img) 181 | 182 | # loss 183 | fake_feats = self.discriminator(CLIP_fake) 184 | output = self.predictor(fake_feats, sent_emb) 185 | text_img_sim = torch.cosine_similarity(fake_emb, sent_emb).mean() 186 | loss = -output.mean() - 4.0 * text_img_sim 187 | 188 | apply_gradients(loss, self.g_optimizer, self.mixed_flag, self.scaler_G, self.scaler_min) 189 | 190 | return loss, sent_emb 191 | 192 | def d_train_step(self, real_img, tokens, device=torch.device('cuda')): 193 | self.generator.train() 194 | self.discriminator.train() 195 | self.predictor.train() 196 | 197 | # step 0: pre-process 198 | with torch.cuda.amp.autocast() if self.mixed_flag else dummy_context_mgr() as mpc: 199 | with torch.no_grad(): 200 | sent_emb, word_emb = self.clip_text(tokens) # [bs, 512], [bs, 77, 512] 201 | word_emb = word_emb.detach() 202 | sent_emb = sent_emb.detach() 203 | 204 | 205 | # loss 206 | real_img = real_img.requires_grad_() 207 | sent_emb = sent_emb.requires_grad_() 208 | word_emb = word_emb.requires_grad_() 209 | 210 | # predict real 211 | CLIP_real, real_emb = self.clip_img(real_img) # [bs, 3, 768, 7, 7], [bs, 512] 212 | real_feats = self.discriminator(CLIP_real) # [bs, 512, 7, 7] 213 | pred_real, errD_real = predict_loss(self.predictor, real_feats, sent_emb, negtive=False) 214 | 215 | # predict mismatch 216 | mis_sent_emb = torch.cat((sent_emb[1:], sent_emb[0:1]), dim=0).detach() 217 | _, errD_mis = predict_loss(self.predictor, real_feats, mis_sent_emb, negtive=True) 218 | 219 | # synthesize fake images 220 | noise = torch.randn([self.batch_size, self.style_dim]).to(device) 221 | fake_img = self.generator(noise, sent_emb) 222 | CLIP_fake, fake_emb = self.clip_img(fake_img) 223 | fake_feats = self.discriminator(CLIP_fake.detach()) 224 | _, errD_fake = predict_loss(self.predictor, fake_feats, sent_emb, negtive=True) 225 | 226 | if self.mixed_flag: 227 | errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, self.scaler_D) 228 | else: 229 | errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real) 230 | 231 | with torch.cuda.amp.autocast() if self.mixed_flag else dummy_context_mgr() as mpc: 232 | loss = errD_real + (errD_fake + errD_mis) / 2.0 + errD_MAGP 233 | 234 | apply_gradients(loss, 
self.d_optimizer, self.mixed_flag, self.scaler_D, self.scaler_min) 235 | 236 | return loss 237 | 238 | def train_model(self, rank, device): 239 | start_time = time.time() 240 | fid_start_time = time.time() 241 | 242 | # setup tensorboards 243 | train_summary_writer = SummaryWriter(self.log_dir) 244 | 245 | # start training 246 | if rank == 0: 247 | print() 248 | print(self.dataset_path) 249 | print("Dataset number : ", self.dataset_num) 250 | print("GPUs : ", self.NUM_GPUS) 251 | print("Each batch size : ", self.batch_size) 252 | print("Global batch size : ", self.global_batch_size) 253 | print("Target image size : ", self.img_size) 254 | print("Print frequency : ", self.print_freq) 255 | print("Save frequency : ", self.save_freq) 256 | print("PyTorch Version :", torch.__version__) 257 | print('max_steps: {}'.format(self.iteration)) 258 | print() 259 | losses = {'g_loss': 0.0, 'd_loss': 0.0} 260 | 261 | fid_dict = {'metric/fid': 0.0, 'metric/best_fid': 0.0, 'metric/best_fid_iter': 0} 262 | fid = 0 263 | best_fid = 1000 264 | best_fid_iter = 0 265 | 266 | for idx in range(self.start_iteration, self.iteration): 267 | iter_start_time = time.time() 268 | 269 | image, tokens, text = next(self.dataset_iter) 270 | image = image.to(device) 271 | tokens = tokens.to(device) 272 | # text = text.to(device) 273 | 274 | if idx == 0: 275 | if rank == 0: 276 | print("count params") 277 | g_params = count_parameters(self.generator) 278 | d_params = count_parameters(self.discriminator) + count_parameters(self.predictor) 279 | g_B, g_M = convert_to_billion_and_million(g_params) 280 | d_B, d_M = convert_to_billion_and_million(d_params) 281 | 282 | t_B = g_B + d_B 283 | t_M = g_M + d_M 284 | 285 | print("G network parameters : {}B, {}M".format(g_B, g_M)) 286 | print("D network parameters : {}B, {}M".format(d_B, d_M)) 287 | print("Total network parameters : {}B, {}M".format(t_B, t_M)) 288 | print() 289 | 290 | loss = self.d_train_step(image, tokens, device=device) 291 | 292 | losses['d_loss'] = loss 293 | 294 | loss, text_embed = self.g_train_step(image, tokens, device) 295 | losses['g_loss'] = loss 296 | 297 | losses = reduce_loss_dict(losses) 298 | losses = dict_to_numpy(losses, python_value=True) 299 | 300 | if np.mod(idx, self.print_freq) == 0 or idx == self.iteration - 1 : 301 | if rank == 0: 302 | print("calculate fid ...") 303 | fid_start_time = time.time() 304 | 305 | fid = calculate_fid_t2i(self.fid_loader, self.generator, self.inception, self.clip_text, self.val_dataset_name, 306 | device=device, latent_dim=self.style_dim) 307 | 308 | if rank == 0: 309 | fid_end_time = time.time() 310 | fid_elapsed = fid_end_time - fid_start_time 311 | print("calculate fid finish: {:.2f}s".format(fid_elapsed)) 312 | if fid < best_fid: 313 | print("BEST FID UPDATED") 314 | best_fid = fid 315 | best_fid_iter = idx 316 | self.torch_save(idx, fid) 317 | 318 | fid_dict['metric/best_fid'] = best_fid 319 | fid_dict['metric/best_fid_iter'] = best_fid_iter 320 | fid_dict['metric/fid'] = fid 321 | 322 | 323 | if rank == 0: 324 | # save to tensorboard 325 | 326 | for k, v in losses.items(): 327 | train_summary_writer.add_scalar(k, v, global_step=idx) 328 | 329 | if np.mod(idx, self.print_freq) == 0 or idx == self.iteration - 1: 330 | train_summary_writer.add_scalar('fid', fid, global_step=idx) 331 | 332 | if np.mod(idx + 1, self.print_freq) == 0: 333 | with torch.no_grad(): 334 | batch_size = text_embed.shape[0] 335 | 336 | noise = torch.randn([batch_size, self.style_dim]).to(device) 337 | self.generator.eval() 338 | fake_img = 
self.generator(noise, text_embed) 339 | fake_img = torch.clamp(fake_img, -1.0, 1.0) 340 | 341 | partial_size = int(batch_size ** 0.5) 342 | 343 | # resize 344 | fake_img = F.interpolate(fake_img, size=256, mode='bicubic', align_corners=True) 345 | torchvision.utils.save_image(fake_img, './{}/fake_{:06d}.png'.format(self.sample_dir, idx + 1), 346 | nrow=partial_size, 347 | normalize=True, range=(-1, 1)) 348 | text_path = './{}/fake_{:06d}.txt'.format(self.sample_dir, idx+1) 349 | with open(text_path, 'w') as f: 350 | f.write('\n'.join(text)) 351 | # normalize = set to the range (0, 1) by range(min, max) 352 | 353 | elapsed = time.time() - iter_start_time 354 | print(self.log_template.format(idx, self.iteration, elapsed, best_fid)) 355 | 356 | dist.barrier() 357 | 358 | if rank == 0: 359 | # save model for final step 360 | self.torch_save(self.iteration, fid) 361 | 362 | print("LAST FID: ", fid) 363 | print("BEST FID: {}, {}".format(best_fid, best_fid_iter)) 364 | print("Total train time: %4.4f" % (time.time() - start_time)) 365 | 366 | dist.barrier() 367 | 368 | def torch_save(self, idx, fid=0): 369 | fid_int = int(fid) 370 | torch.save( 371 | self.ckpt_dict, 372 | os.path.join(self.checkpoint_dir, 'iter_{}_fid_{}.pt'.format(idx, fid_int)) 373 | ) 374 | 375 | @property 376 | def model_dir(self): 377 | return "{}_{}_{}_bs{}_{}GPUs_Mixed{}".format(self.model_name, self.dataset_name, self.img_size, self.batch_size, self.NUM_GPUS, self.mixed_flag) -------------------------------------------------------------------------------- /src/GALIP/dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/.DS_Store -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/.DS_Store -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/.DS_Store -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/image/000000000009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000009.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/image/000000000025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000025.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/image/000000000030.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000030.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/image/000000000034.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000034.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/image/000000000036.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000036.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/image/000000000042.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000042.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/text/000000000009.txt: -------------------------------------------------------------------------------- 1 | Closeup of bins of food that include broccoli and bread. 2 | A meal is presented in brightly colored plastic trays. 3 | there are containers filled with different kinds of foods 4 | Colorful dishes holding meat, vegetables, fruit, and bread. 5 | A bunch of trays that have different food. 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/text/000000000025.txt: -------------------------------------------------------------------------------- 1 | A giraffe eating food from the top of the tree. 2 | A giraffe standing up nearby a tree 3 | A giraffe mother with its baby in the forest. 4 | Two giraffes standing in a tree filled area. 5 | A giraffe standing next to a forest filled with trees. 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/text/000000000030.txt: -------------------------------------------------------------------------------- 1 | A flower vase is sitting on a porch stand. 2 | White vase with different colored flowers sitting inside of it. 3 | a white vase with many flowers on a stage 4 | A white vase filled with different colored flowers. 5 | A vase with red and white flowers outside on a sunny day. 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/text/000000000034.txt: -------------------------------------------------------------------------------- 1 | A zebra grazing on lush green grass in a field. 2 | Zebra reaching its head down to ground where grass is. 3 | The zebra is eating grass in the sun. 4 | A lone zebra grazing in some green grass. 5 | a Zebra grazing on grass in a green open field. 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/text/000000000036.txt: -------------------------------------------------------------------------------- 1 | Woman in swim suit holding parasol on sunny day. 
2 | A woman posing for the camera, holding a pink, open umbrella and wearing a bright, floral, ruched bathing suit, by a life guard stand with lake, green trees, and a blue sky with a few clouds behind. 3 | A woman in a floral swimsuit holds a pink umbrella. 4 | A woman with an umbrella near the sea 5 | A girl in a bathing suit with a pink umbrella. 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/train/text/000000000042.txt: -------------------------------------------------------------------------------- 1 | This wire metal rack holds several pairs of shoes and sandals 2 | A dog sleeping on a show rack in the shoes. 3 | Various slides and other footwear rest in a metal basket outdoors. 4 | A small dog is curled up on top of the shoes 5 | a shoe rack with some shoes and a dog sleeping on them 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/.DS_Store -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/image/000000000139.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000139.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/image/000000000285.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000285.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/image/000000000632.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000632.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/image/000000000724.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000724.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/image/000000000776.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000776.jpg -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/image/000000000785.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000785.jpg 
-------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/text/000000000139.txt: -------------------------------------------------------------------------------- 1 | A woman stands in the dining area at the table. 2 | A room with chairs, a table, and a woman in it. 3 | A woman standing in a kitchen by a window 4 | A person standing at a table in a room. 5 | A living area with a television and a table 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/text/000000000285.txt: -------------------------------------------------------------------------------- 1 | A big burly grizzly bear is show with grass in the background. 2 | The large brown bear has a black nose. 3 | Closeup of a brown bear sitting in a grassy area. 4 | A large bear that is sitting on grass. 5 | A close up picture of a brown bear's face. 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/text/000000000632.txt: -------------------------------------------------------------------------------- 1 | Bedroom scene with a bookcase, blue comforter and window. 2 | A bedroom with a bookshelf full of books. 3 | This room has a bed with blue sheets and a large bookcase 4 | A bed and a mirror in a small room. 5 | a bed room with a neatly made bed a window and a book shelf 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/text/000000000724.txt: -------------------------------------------------------------------------------- 1 | A stop sign is mounted upside-down on it's post. 2 | A stop sign that is hanging upside down. 3 | An upside down stop sign by the road. 4 | a stop sign put upside down on a metal pole 5 | A stop sign installed upside down on a street corner 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/text/000000000776.txt: -------------------------------------------------------------------------------- 1 | Three teddy bears, each a different color, snuggling together. 2 | Three stuffed animals are sitting on a bed. 3 | three teddy bears giving each other a hug 4 | A group of three stuffed animal teddy bears. 5 | Three stuffed bears hugging and sitting on a blue pillow 6 | -------------------------------------------------------------------------------- /src/GALIP/dataset/coco_2017/val/text/000000000785.txt: -------------------------------------------------------------------------------- 1 | A woman posing for the camera standing on skis. 2 | a woman standing on skiis while posing for the camera 3 | A woman in a red jacket skiing down a slope 4 | A young woman is skiing down the mountain slope. 
5 | a person on skis makes her way through the snow 6 | -------------------------------------------------------------------------------- /src/GALIP/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import * 3 | from GALIP import run_fn 4 | 5 | """ 6 | count params 7 | G network parameters : 51,830,211 8 | D network parameters : 30,806,021 9 | Total network parameters : 82,636,232 10 | """ 11 | 12 | def parse_args(): 13 | desc = "Pytorch implementation of GALIP" 14 | parser = argparse.ArgumentParser(description=desc) 15 | parser.add_argument('--phase', type=str, default='train', help='[train, test]') 16 | parser.add_argument('--dataset', type=str, default='coco_2017', help='dataset_name') 17 | # celeba_hq_text 18 | # coco_2017 19 | parser.add_argument('--epoch', type=int, default=3000, help='The total epoch') 20 | parser.add_argument('--iteration', type=int, default=1000000, help='The total iterations') 21 | parser.add_argument('--img_size', type=int, default=256, help='The size of image') 22 | parser.add_argument('--batch_size', type=int, default=64, help='batch sizes for each gpus') 23 | parser.add_argument('--mixed_flag', type=str2bool, default=True, help='Mixed Precision Flag') 24 | # single = 16 25 | 26 | # StyleGAN paraeter 27 | parser.add_argument("--g_lr", type=float, default=0.0001, help="g learning rate") 28 | parser.add_argument("--d_lr", type=float, default=0.0004, help="d learning rate") 29 | 30 | parser.add_argument('--print_freq', type=int, default=5000, help='The number of image_print_freq') 31 | parser.add_argument('--save_freq', type=int, default=50000, help='The number of ckpt_save_freq') 32 | 33 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', 34 | help='Directory name to save the checkpoints') 35 | parser.add_argument('--result_dir', type=str, default='results', 36 | help='Directory name to save the generated images') 37 | parser.add_argument('--log_dir', type=str, default='logs', 38 | help='Directory name to save training logs') 39 | parser.add_argument('--sample_dir', type=str, default='samples', 40 | help='Directory name to save the samples_prev on training') 41 | 42 | return check_args(parser.parse_args()) 43 | 44 | 45 | """checking arguments""" 46 | def check_args(args): 47 | # --checkpoint_dir 48 | check_folder(args.checkpoint_dir) 49 | 50 | # --result_dir 51 | check_folder(args.result_dir) 52 | 53 | # --result_dir 54 | check_folder(args.log_dir) 55 | 56 | # --sample_dir 57 | check_folder(args.sample_dir) 58 | 59 | # --batch_size 60 | try: 61 | assert args.batch_size >= 1 62 | except: 63 | print('batch size must be larger than or equal to one', flush=True) 64 | 65 | return args 66 | 67 | """main""" 68 | def main(): 69 | 70 | args = vars(parse_args()) 71 | 72 | # run 73 | multi_gpu_run(ddp_fn=run_fn, args=args) 74 | 75 | if __name__ == '__main__': 76 | main() -------------------------------------------------------------------------------- /src/GALIP/metric/fid_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from torchvision import models 6 | import torch.distributed as dist 7 | import math 8 | from tqdm import tqdm 9 | from torchvision import transforms 10 | from scipy import linalg 11 | import pickle, os 12 | from torch.nn.functional import adaptive_avg_pool2d 13 | 14 | class dummy_context_mgr(): 15 | def 
__enter__(self): 16 | return None 17 | 18 | def __exit__(self, exc_type, exc_value, traceback): 19 | return False 20 | 21 | class GatherLayer(torch.autograd.Function): 22 | @staticmethod 23 | def forward(ctx, input): 24 | ctx.save_for_backward(input) 25 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 26 | dist.all_gather(output, input) 27 | return tuple(output) 28 | 29 | @staticmethod 30 | def backward(ctx, *grads): 31 | input, = ctx.saved_tensors 32 | grad_out = torch.zeros_like(input) 33 | grad_out[:] = grads[dist.get_rank()] 34 | return grad_out 35 | 36 | 37 | class InceptionV3_(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | inception = models.inception_v3(weights='DEFAULT') 41 | # pretrained=True -> weights=Inception_V3_Weights.IMAGENET1K_V1 42 | # weights='DEFAULT' or weights='IMAGENET1K_V1' 43 | self.block1 = nn.Sequential( 44 | inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, 45 | inception.Conv2d_2b_3x3, 46 | nn.MaxPool2d(kernel_size=3, stride=2)) 47 | self.block2 = nn.Sequential( 48 | inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, 49 | nn.MaxPool2d(kernel_size=3, stride=2)) 50 | self.block3 = nn.Sequential( 51 | inception.Mixed_5b, inception.Mixed_5c, 52 | inception.Mixed_5d, inception.Mixed_6a, 53 | inception.Mixed_6b, inception.Mixed_6c, 54 | inception.Mixed_6d, inception.Mixed_6e) 55 | self.block4 = nn.Sequential( 56 | inception.Mixed_7a, inception.Mixed_7b, 57 | inception.Mixed_7c, 58 | nn.AdaptiveAvgPool2d(output_size=(1, 1))) 59 | 60 | def forward(self, x): 61 | x = self.block1(x) 62 | x = self.block2(x) 63 | x = self.block3(x) 64 | x = self.block4(x) 65 | return x.view(x.size(0), -1) 66 | 67 | class InceptionV3(nn.Module): 68 | """Pretrained InceptionV3 network returning feature maps""" 69 | 70 | # Index of default block of inception to return, 71 | # corresponds to output of final average pooling 72 | DEFAULT_BLOCK_INDEX = 3 73 | 74 | # Maps feature dimensionality to their output blocks indices 75 | BLOCK_INDEX_BY_DIM = { 76 | 64: 0, # First max pooling features 77 | 192: 1, # Second max pooling featurs 78 | 768: 2, # Pre-aux classifier features 79 | 2048: 3 # Final average pooling features 80 | } 81 | 82 | def __init__(self, 83 | mixed_precision=False, 84 | output_blocks=[DEFAULT_BLOCK_INDEX], 85 | resize_input=True, 86 | normalize_input=True, 87 | requires_grad=False): 88 | """Build pretrained InceptionV3 89 | 90 | Parameters 91 | ---------- 92 | output_blocks : list of int 93 | Indices of blocks to return features of. Possible values are: 94 | - 0: corresponds to output of first max pooling 95 | - 1: corresponds to output of second max pooling 96 | - 2: corresponds to output which is fed to aux classifier 97 | - 3: corresponds to output of final average pooling 98 | resize_input : bool 99 | If true, bilinearly resizes input to width and height 299 before 100 | feeding input to model. As the network without fully connected 101 | layers is fully convolutional, it should be able to handle inputs 102 | of arbitrary size, so resizing might not be strictly needed 103 | normalize_input : bool 104 | If true, normalizes the input to the statistics the pretrained 105 | Inception network expects 106 | requires_grad : bool 107 | If true, parameters of the model require gradient. 
Possibly useful 108 | for finetuning the network 109 | """ 110 | super(InceptionV3, self).__init__() 111 | 112 | self.resize_input = resize_input 113 | self.normalize_input = normalize_input 114 | self.output_blocks = sorted(output_blocks) 115 | self.last_needed_block = max(output_blocks) 116 | 117 | assert self.last_needed_block <= 3, \ 118 | 'Last possible output block index is 3' 119 | 120 | self.blocks = nn.ModuleList() 121 | 122 | inception = models.inception_v3(pretrained=True) 123 | 124 | # Block 0: input to maxpool1 125 | block0 = [ 126 | inception.Conv2d_1a_3x3, 127 | inception.Conv2d_2a_3x3, 128 | inception.Conv2d_2b_3x3, 129 | nn.MaxPool2d(kernel_size=3, stride=2) 130 | ] 131 | self.blocks.append(nn.Sequential(*block0)) 132 | 133 | # Block 1: maxpool1 to maxpool2 134 | if self.last_needed_block >= 1: 135 | block1 = [ 136 | inception.Conv2d_3b_1x1, 137 | inception.Conv2d_4a_3x3, 138 | nn.MaxPool2d(kernel_size=3, stride=2) 139 | ] 140 | self.blocks.append(nn.Sequential(*block1)) 141 | 142 | # Block 2: maxpool2 to aux classifier 143 | if self.last_needed_block >= 2: 144 | block2 = [ 145 | inception.Mixed_5b, 146 | inception.Mixed_5c, 147 | inception.Mixed_5d, 148 | inception.Mixed_6a, 149 | inception.Mixed_6b, 150 | inception.Mixed_6c, 151 | inception.Mixed_6d, 152 | inception.Mixed_6e, 153 | ] 154 | self.blocks.append(nn.Sequential(*block2)) 155 | 156 | # Block 3: aux classifier to final avgpool 157 | if self.last_needed_block >= 3: 158 | block3 = [ 159 | inception.Mixed_7a, 160 | inception.Mixed_7b, 161 | inception.Mixed_7c, 162 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 163 | ] 164 | self.blocks.append(nn.Sequential(*block3)) 165 | 166 | for param in self.parameters(): 167 | param.requires_grad = requires_grad 168 | 169 | def forward(self, inp): 170 | """Get Inception feature maps 171 | 172 | Parameters 173 | ---------- 174 | inp : torch.autograd.Variable 175 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 176 | range (0, 1) 177 | 178 | Returns 179 | ------- 180 | List of torch.autograd.Variable, corresponding to the selected output 181 | block, sorted ascending by index 182 | """ 183 | outp = [] 184 | x = inp 185 | 186 | if self.resize_input: 187 | x = F.upsample(x, size=(299, 299), mode='bilinear', align_corners=True) 188 | 189 | if self.normalize_input: 190 | x = x.clone() 191 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 192 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 193 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 194 | 195 | for idx, block in enumerate(self.blocks): 196 | x = block(x) 197 | if idx in self.output_blocks: 198 | outp.append(x) 199 | 200 | if idx == self.last_needed_block: 201 | break 202 | 203 | return outp 204 | 205 | def extract_real_feature(data_loader, inception, device, t2i=False): 206 | feats = [] 207 | 208 | if t2i: 209 | for img, tokens, txt in tqdm(data_loader): 210 | img = img.to(device) 211 | feat = inception(img) 212 | 213 | feats.append(feat) 214 | else: 215 | for img in tqdm(data_loader): 216 | img = img.to(device) 217 | feat = inception(img) 218 | 219 | feats.append(feat) 220 | 221 | feats = gather_feats(feats) 222 | 223 | return feats 224 | 225 | def normalize_fake_img(imgs): 226 | """ 227 | mean = [0.485, 0.456, 0.406] 228 | std = [0.229, 0.224, 0.225] 229 | 230 | imgs = (imgs + 1) / 2 # -1 ~ 1 to 0~1 231 | imgs = torch.clamp(imgs, 0, 1) 232 | imgs = transforms.Resize(size=[299, 299], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)(imgs), 233 | imgs = transforms.Normalize(mean=mean, std=std)(imgs) 234 | """ 235 | 236 | norm = transforms.Compose([ 237 | transforms.Normalize((-1, -1, -1), (2, 2, 2)), # (x - (-1)) / 2 = (x + 1) / 2 238 | transforms.Resize((299, 299)), 239 | ]) 240 | 241 | x = norm(imgs) 242 | 243 | return x 244 | 245 | def gather_feats(feats): 246 | feats = torch.cat(feats, dim=0) 247 | feats = torch.cat(GatherLayer.apply(feats), dim=0) 248 | feats = feats.detach().cpu().numpy() 249 | 250 | return feats 251 | 252 | def extract_fake_feature(generator, inception, num_gpus, device, latent_dim, fake_samples=50000, batch_size=16): 253 | num_batches = int(math.ceil(float(fake_samples) / float(batch_size * num_gpus))) 254 | feats = [] 255 | for _ in tqdm(range(num_batches)): 256 | z = [torch.randn([batch_size, latent_dim], device=device)] 257 | fake_img = generator(z) 258 | 259 | fake_img = normalize_fake_img(fake_img) 260 | 261 | feat = inception(fake_img) 262 | 263 | feats.append(feat) 264 | 265 | feats = gather_feats(feats) 266 | 267 | return feats 268 | 269 | def extract_fake_feature_t2i(data_loader, generator, inception, clip_text, device, latent_dim=100, mixed_flag=False): 270 | # with torch.cuda.amp.autocast() if mixed_flag else dummy_context_mgr() as mpc: 271 | with torch.no_grad(): 272 | feats = [] 273 | try: 274 | for img, tokens, txt in tqdm(data_loader): 275 | # pre-process 276 | tokens = tokens.to(device) 277 | sent_emb, word_emb = clip_text(tokens) # [bs, 512], [bs, 77, 512] 278 | sent_emb = sent_emb.detach() 279 | 280 | # make fake_img 281 | noise = torch.randn([sent_emb.shape[0], latent_dim]).to(device) 282 | fake_img = generator(noise, sent_emb) 283 | fake_img = fake_img.float() 284 | fake_img = torch.clamp(fake_img, -1., 1.) 
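            # Under mixed precision the generator can occasionally emit NaN/Inf pixels;
            # the next line maps them back into the valid [-1, 1] range before the
            # Inception preprocessing used for FID.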
285 | fake_img = torch.nan_to_num(fake_img, nan=-1.0, posinf=1.0, neginf=-1.0) 286 | 287 | # get features of inception 288 | fake_img = normalize_fake_img(fake_img) 289 | feat = inception(fake_img) 290 | 291 | # galip 292 | pred = feat[0] 293 | if pred.shape[2] != 1 or pred.shape[3] != 1: 294 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 295 | pred = pred.squeeze(-1).squeeze(-1) 296 | feats.append(pred) 297 | 298 | except IndexError: 299 | pass 300 | 301 | feats = gather_feats(feats) 302 | 303 | return feats 304 | 305 | def get_statistics(feats): 306 | mu = np.mean(feats, axis=0) 307 | cov = np.cov(feats, rowvar=False) 308 | 309 | return mu, cov 310 | 311 | def frechet_distance(mu, cov, mu2, cov2): 312 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False) 313 | dist = np.sum((mu -mu2)**2) + np.trace(cov + cov2 - 2*cc) 314 | return np.real(dist) 315 | 316 | @torch.no_grad() 317 | def calculate_fid(data_loader, generator_model, inception_model, dataset_name, rank, device, 318 | latent_dim, fake_samples=50000, batch_size=16): 319 | 320 | num_gpus = torch.cuda.device_count() 321 | 322 | generator_model = generator_model.eval() 323 | inception_model = inception_model.eval() 324 | 325 | pickle_name = '{}_mu_cov.pickle'.format(dataset_name) 326 | cache = os.path.exists(pickle_name) 327 | 328 | if cache: 329 | with open(pickle_name, 'rb') as f: 330 | real_mu, real_cov = pickle.load(f) 331 | else: 332 | real_feats = extract_real_feature(data_loader, inception_model, device=device) 333 | real_mu, real_cov = get_statistics(real_feats) 334 | 335 | if rank == 0: 336 | with open(pickle_name, 'wb') as f: 337 | pickle.dump((real_mu, real_cov), f, protocol=pickle.HIGHEST_PROTOCOL) 338 | 339 | 340 | fake_feats = extract_fake_feature(generator_model, inception_model, num_gpus, device, latent_dim, fake_samples, batch_size) 341 | fake_mu, fake_cov = get_statistics(fake_feats) 342 | 343 | fid = frechet_distance(real_mu, real_cov, fake_mu, fake_cov) 344 | 345 | return fid 346 | 347 | @torch.no_grad() 348 | def calculate_fid_t2i(data_loader, generator, inception, clip_text, dataset_name, device, 349 | latent_dim=100, mixed_flag=False): 350 | # coco: 5000 351 | 352 | generator = generator.eval() 353 | inception = inception.eval() 354 | clip_text = clip_text.eval() 355 | 356 | stats_path = '{}_fid_stats.npz'.format(dataset_name) 357 | x = np.load(stats_path) 358 | real_mu, real_cov = x['mu'], x['sigma'] 359 | 360 | 361 | fake_feats = extract_fake_feature_t2i(data_loader, generator, inception, clip_text, device, latent_dim, mixed_flag) 362 | fake_mu, fake_cov = get_statistics(fake_feats) 363 | 364 | fid = frechet_distance(real_mu, real_cov, fake_mu, fake_cov) 365 | 366 | return fid -------------------------------------------------------------------------------- /src/GALIP/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | import clip 7 | # clip : CLIP4evl 8 | 9 | """ 10 | Compose( 11 | Resize(size=224, interpolation=bicubic, max_size=None, antialias=None) 12 | CenterCrop(size=(224, 224)) 13 | 14 | ToTensor() 15 | Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 16 | ) 17 | """ 18 | 19 | class dummy_context_mgr(): 20 | def __enter__(self): 21 | return None 22 | 23 | def __exit__(self, exc_type, exc_value, traceback): 24 | return False 25 | 26 | class CLIP_IMG_ENCODER(nn.Module): 
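    # Frozen wrapper around the CLIP visual transformer (ViT-B/32). forward() first
    # remaps generator images from [-1, 1] to CLIP's input normalization, then returns
    # (a) patch-token feature maps from a few selected transformer blocks, stacked for
    # the discriminator, and (b) the final projected image embedding.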
27 | def __init__(self, CLIP): 28 | super(CLIP_IMG_ENCODER, self).__init__() 29 | model = CLIP.visual 30 | # print(model) 31 | self.define_module(model) 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def define_module(self, model): 36 | self.conv1 = model.conv1 37 | self.class_embedding = model.class_embedding 38 | self.positional_embedding = model.positional_embedding 39 | self.ln_pre = model.ln_pre 40 | self.transformer = model.transformer 41 | self.ln_post = model.ln_post 42 | self.proj = model.proj 43 | 44 | @property 45 | def dtype(self): 46 | return self.conv1.weight.dtype 47 | 48 | def transf_to_CLIP_input(self,inputs): 49 | device = inputs.device 50 | if len(inputs.size()) != 4: 51 | raise ValueError('Expect the (B, C, X, Y) tensor.') 52 | else: 53 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])\ 54 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) 55 | var = torch.tensor([0.26862954, 0.26130258, 0.27577711])\ 56 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device) 57 | inputs = F.interpolate(inputs*0.5+0.5, size=(224, 224)) 58 | # inputs = ((inputs+1)*0.5-mean)/var 59 | inputs = (inputs - mean) / var 60 | return inputs 61 | 62 | def forward(self, img: torch.Tensor): 63 | x = self.transf_to_CLIP_input(img) 64 | x = x.type(self.dtype) 65 | x = self.conv1(x) # shape = [*, width, grid, grid] 66 | grid = x.size(-1) 67 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 68 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 69 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 70 | x = x + self.positional_embedding.to(x.dtype) 71 | x = self.ln_pre(x) 72 | # NLD -> LND 73 | x = x.permute(1, 0, 2) 74 | # Local features 75 | #selected = [1,4,7,12] 76 | selected = [1,4,8] 77 | local_features = [] 78 | for i in range(12): 79 | x = self.transformer.resblocks[i](x) 80 | if i in selected: 81 | local_features.append(x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)) 82 | x = x.permute(1, 0, 2) # LND -> NLD 83 | x = self.ln_post(x[:, 0, :]) 84 | if self.proj is not None: 85 | x = x @ self.proj 86 | return torch.stack(local_features, dim=1), x.type(img.dtype) 87 | 88 | 89 | class CLIP_TXT_ENCODER(nn.Module): 90 | def __init__(self, CLIP): 91 | super(CLIP_TXT_ENCODER, self).__init__() 92 | self.define_module(CLIP) 93 | # print(model) 94 | for param in self.parameters(): 95 | param.requires_grad = False 96 | 97 | def define_module(self, CLIP): 98 | self.transformer = CLIP.transformer 99 | self.vocab_size = CLIP.vocab_size 100 | self.token_embedding = CLIP.token_embedding 101 | self.positional_embedding = CLIP.positional_embedding 102 | self.ln_final = CLIP.ln_final 103 | self.text_projection = CLIP.text_projection 104 | 105 | @property 106 | def dtype(self): 107 | return self.transformer.resblocks[0].mlp.c_fc.weight.dtype 108 | 109 | def forward(self, text): 110 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 111 | x = x + self.positional_embedding.type(self.dtype) 112 | x = x.permute(1, 0, 2) # NLD -> LND 113 | x = self.transformer(x) 114 | x = x.permute(1, 0, 2) # LND -> NLD 115 | x = self.ln_final(x).type(self.dtype) 116 | # x.shape = [batch_size, n_ctx, transformer.width] 117 | # take features from the eot embedding (eot_token is the highest number in each sequence) 118 | sent_emb = x[torch.arange(x.shape[0]), 
text.argmax(dim=-1)] @ self.text_projection 119 | return sent_emb, x 120 | 121 | 122 | class CLIP_Mapper(nn.Module): 123 | def __init__(self, CLIP): 124 | super(CLIP_Mapper, self).__init__() 125 | model = CLIP.visual 126 | # print(model) 127 | self.define_module(model) 128 | for param in model.parameters(): 129 | param.requires_grad = False 130 | 131 | def define_module(self, model): 132 | self.conv1 = model.conv1 133 | self.class_embedding = model.class_embedding 134 | self.positional_embedding = model.positional_embedding 135 | self.ln_pre = model.ln_pre 136 | self.transformer = model.transformer 137 | 138 | @property 139 | def dtype(self): 140 | return self.conv1.weight.dtype 141 | 142 | def forward(self, img: torch.Tensor, prompts: torch.Tensor): 143 | x = img.type(self.dtype) 144 | prompts = prompts.type(self.dtype) 145 | grid = x.size(-1) 146 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 147 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 148 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) 149 | # shape = [*, grid ** 2 + 1, width] 150 | x = x + self.positional_embedding.to(x.dtype) 151 | x = self.ln_pre(x) 152 | # NLD -> LND 153 | x = x.permute(1, 0, 2) 154 | # Local features 155 | selected = [1,2,3,4,5,6,7,8] 156 | begin, end = 0, 12 157 | prompt_idx = 0 158 | for i in range(begin, end): 159 | if i in selected: 160 | prompt = prompts[:,prompt_idx,:].unsqueeze(0) 161 | prompt_idx = prompt_idx+1 162 | x = torch.cat((x,prompt), dim=0) 163 | x = self.transformer.resblocks[i](x) 164 | x = x[:-1,:,:] 165 | else: 166 | x = self.transformer.resblocks[i](x) 167 | return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype) 168 | 169 | 170 | class CLIP_Adapter(nn.Module): 171 | def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP): 172 | super(CLIP_Adapter, self).__init__() 173 | self.CLIP_ch = CLIP_ch 174 | self.FBlocks = nn.ModuleList([]) 175 | self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p)) 176 | for i in range(map_num-1): 177 | self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p)) 178 | self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2) 179 | self.CLIP_ViT = CLIP_Mapper(CLIP) 180 | self.conv = nn.Conv2d(768, G_ch, 5, 1, 2) 181 | # 182 | self.fc_prompt = nn.Linear(cond_dim, CLIP_ch*8) 183 | 184 | def forward(self,out,c): 185 | prompts = self.fc_prompt(c).view(c.size(0),-1,self.CLIP_ch) 186 | # [1, 8, 768] 187 | for FBlock in self.FBlocks: 188 | out = FBlock(out,c) 189 | # out -> [1, 64, 7, 7] 190 | fuse_feat = self.conv_fuse(out) 191 | # fuse_feat -> [1, 768, 7, 7] 192 | map_feat = self.CLIP_ViT(fuse_feat,prompts) 193 | # map_feat -> [1, 768, 7, 7] 194 | return self.conv(fuse_feat+0.1*map_feat) # [1, 512, 7, 7] 195 | 196 | 197 | class NetG(nn.Module): 198 | def __init__(self, imsize, CLIP, ngf=64, nz=100, cond_dim=512, ch_size=3, mixed_precision=False): 199 | super(NetG, self).__init__() 200 | self.ngf = ngf 201 | self.mixed_precision = mixed_precision 202 | # build CLIP Mapper 203 | self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32 204 | self.CLIP_ch = 768 205 | self.fc_code = nn.Linear(nz, self.code_sz*self.code_sz*self.code_ch) 206 | self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf*8, self.CLIP_ch, cond_dim+nz, 3, 1, 1, 4, CLIP) 207 | # build GBlocks 208 | self.GBlocks = nn.ModuleList([]) 209 | in_out_pairs = 
list(get_G_in_out_chs(ngf, imsize)) 210 | imsize = 4 211 | for idx, (in_ch, out_ch) in enumerate(in_out_pairs): 212 | if idx<(len(in_out_pairs)-1): 213 | imsize = imsize*2 214 | else: 215 | imsize = 224 216 | self.GBlocks.append(G_Block(cond_dim+nz, in_ch, out_ch, imsize)) 217 | # to RGB image 218 | self.to_rgb = nn.Sequential( 219 | nn.LeakyReLU(0.2,inplace=True), 220 | nn.Conv2d(out_ch, ch_size, 3, 1, 1), 221 | #nn.Tanh(), 222 | ) 223 | 224 | def forward(self, noise, c, eval=False): # x=noise, c=ent_emb 225 | with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp: 226 | cond = torch.cat((noise, c), dim=1) # 612 dim, 100 + 512 227 | out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond) 228 | # fc_code -> [1, 64, 7, 7] 229 | # out -> [1, 512, 7, 7] 230 | # fuse text and visual features 231 | # 이미지 늘리기 232 | for GBlock in self.GBlocks: 233 | out = GBlock(out, cond) 234 | # [1, 64, 224, 224] 235 | # convert to RGB image 236 | out = self.to_rgb(out) 237 | return out 238 | 239 | 240 | class NetD(nn.Module): 241 | def __init__(self, imsize, ndf=64, ch_size=3, mixed_precision=False): 242 | super(NetD, self).__init__() 243 | self.mixed_precision = mixed_precision 244 | self.DBlocks = nn.ModuleList([ 245 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True), 246 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True), 247 | ]) 248 | self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False) 249 | 250 | def forward(self, h): 251 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc: 252 | out = h[:,0] 253 | for idx in range(len(self.DBlocks)): 254 | out = self.DBlocks[idx](out, h[:,idx+1]) 255 | out = self.main(out) 256 | return out 257 | 258 | 259 | class NetC(nn.Module): 260 | def __init__(self, ndf=64, cond_dim=512, mixed_precision=False): 261 | super(NetC, self).__init__() 262 | self.cond_dim = cond_dim 263 | self.mixed_precision = mixed_precision 264 | self.joint_conv = nn.Sequential( 265 | nn.Conv2d(512+512, 128, 4, 1, 0, bias=False), 266 | nn.LeakyReLU(0.2, inplace=True), 267 | nn.Conv2d(128, 1, 4, 1, 0, bias=False), 268 | ) 269 | 270 | def forward(self, out, cond): 271 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc: 272 | cond = cond.view(-1, self.cond_dim, 1, 1) 273 | cond = cond.repeat(1, 1, 7, 7) 274 | h_c_code = torch.cat((out, cond), 1) 275 | out = self.joint_conv(h_c_code) 276 | return out 277 | 278 | 279 | class M_Block(nn.Module): 280 | def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p): 281 | super(M_Block, self).__init__() 282 | self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p) 283 | self.fuse1 = DFBLK(cond_dim, mid_ch) 284 | self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p) 285 | self.fuse2 = DFBLK(cond_dim, out_ch) 286 | self.learnable_sc = in_ch != out_ch 287 | if self.learnable_sc: 288 | self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) 289 | 290 | def shortcut(self, x): 291 | if self.learnable_sc: 292 | x = self.c_sc(x) 293 | return x 294 | 295 | def residual(self, h, text): 296 | h = self.conv1(h) 297 | h = self.fuse1(h, text) 298 | h = self.conv2(h) 299 | h = self.fuse2(h, text) 300 | return h 301 | 302 | def forward(self, h, c): 303 | return self.shortcut(h) + self.residual(h, c) 304 | 305 | 306 | class G_Block(nn.Module): 307 | def __init__(self, cond_dim, in_ch, out_ch, imsize): 308 | super(G_Block, self).__init__() 309 | self.imsize = imsize 310 | self.learnable_sc = in_ch != 
out_ch 311 | self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1) 312 | self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1) 313 | self.fuse1 = DFBLK(cond_dim, in_ch) 314 | self.fuse2 = DFBLK(cond_dim, out_ch) 315 | if self.learnable_sc: 316 | self.c_sc = nn.Conv2d(in_ch,out_ch, 1, stride=1, padding=0) 317 | 318 | def shortcut(self, x): 319 | if self.learnable_sc: 320 | x = self.c_sc(x) 321 | return x 322 | 323 | def residual(self, h, y): 324 | h = self.fuse1(h, y) 325 | h = self.c1(h) 326 | h = self.fuse2(h, y) 327 | h = self.c2(h) 328 | return h 329 | 330 | def forward(self, h, y): 331 | h = F.interpolate(h, size=(self.imsize, self.imsize)) 332 | return self.shortcut(h) + self.residual(h, y) 333 | 334 | 335 | class D_Block(nn.Module): 336 | def __init__(self, fin, fout, k, s, p, res, CLIP_feat): 337 | super(D_Block, self).__init__() 338 | self.res, self.CLIP_feat = res, CLIP_feat 339 | self.learned_shortcut = (fin != fout) 340 | self.conv_r = nn.Sequential( 341 | nn.Conv2d(fin, fout, k, s, p, bias=False), 342 | nn.LeakyReLU(0.2, inplace=True), 343 | nn.Conv2d(fout, fout, k, s, p, bias=False), 344 | nn.LeakyReLU(0.2, inplace=True), 345 | ) 346 | self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0) 347 | if self.res==True: 348 | self.gamma = nn.Parameter(torch.zeros(1)) 349 | if self.CLIP_feat==True: 350 | self.beta = nn.Parameter(torch.zeros(1)) 351 | 352 | def forward(self, x, CLIP_feat=None): 353 | res = self.conv_r(x) 354 | if self.learned_shortcut: 355 | x = self.conv_s(x) 356 | if (self.res==True)and(self.CLIP_feat==True): 357 | return x + self.gamma*res + self.beta*CLIP_feat 358 | elif (self.res==True)and(self.CLIP_feat!=True): 359 | return x + self.gamma*res 360 | elif (self.res!=True)and(self.CLIP_feat==True): 361 | return x + self.beta*CLIP_feat 362 | else: 363 | return x 364 | 365 | 366 | class DFBLK(nn.Module): 367 | def __init__(self, cond_dim, in_ch): 368 | super(DFBLK, self).__init__() 369 | self.affine0 = Affine(cond_dim, in_ch) 370 | self.affine1 = Affine(cond_dim, in_ch) 371 | 372 | def forward(self, x, y=None): 373 | h = self.affine0(x, y) 374 | h = nn.LeakyReLU(0.2,inplace=True)(h) 375 | h = self.affine1(h, y) 376 | h = nn.LeakyReLU(0.2,inplace=True)(h) 377 | return h 378 | 379 | 380 | class QuickGELU(nn.Module): 381 | def forward(self, x: torch.Tensor): 382 | return x * torch.sigmoid(1.702 * x) 383 | 384 | 385 | class Affine(nn.Module): 386 | def __init__(self, cond_dim, num_features): 387 | super(Affine, self).__init__() 388 | 389 | self.fc_gamma = nn.Sequential(OrderedDict([ 390 | ('linear1',nn.Linear(cond_dim, num_features)), 391 | ('relu1',nn.ReLU(inplace=True)), 392 | ('linear2',nn.Linear(num_features, num_features)), 393 | ])) 394 | self.fc_beta = nn.Sequential(OrderedDict([ 395 | ('linear1',nn.Linear(cond_dim, num_features)), 396 | ('relu1',nn.ReLU(inplace=True)), 397 | ('linear2',nn.Linear(num_features, num_features)), 398 | ])) 399 | self._initialize() 400 | 401 | def _initialize(self): 402 | nn.init.zeros_(self.fc_gamma.linear2.weight.data) 403 | nn.init.ones_(self.fc_gamma.linear2.bias.data) 404 | nn.init.zeros_(self.fc_beta.linear2.weight.data) 405 | nn.init.zeros_(self.fc_beta.linear2.bias.data) 406 | 407 | def forward(self, x, y=None): 408 | weight = self.fc_gamma(y) 409 | bias = self.fc_beta(y) 410 | 411 | if weight.dim() == 1: 412 | weight = weight.unsqueeze(0) 413 | if bias.dim() == 1: 414 | bias = bias.unsqueeze(0) 415 | 416 | size = x.size() 417 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 418 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 
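        # Channel-wise FiLM-style modulation: `weight` (gamma) and `bias` (beta) are
        # predicted from the condition y and broadcast over the spatial dimensions.
        # Because _initialize() zeroes the last linear weights and sets gamma's bias to 1
        # and beta's bias to 0, the layer starts out as an identity mapping and only
        # gradually learns to modulate the feature maps as training progresses.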
419 | return weight * x + bias 420 | 421 | 422 | def get_G_in_out_chs(nf, imsize): 423 | layer_num = int(np.log2(imsize))-1 424 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)] 425 | channel_nums = channel_nums[::-1] 426 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:]) 427 | return in_out_pairs 428 | 429 | 430 | def get_D_in_out_chs(nf, imsize): 431 | layer_num = int(np.log2(imsize))-1 432 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)] 433 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:]) 434 | return in_out_pairs -------------------------------------------------------------------------------- /src/GALIP/ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | from torchvision import transforms 6 | 7 | def get_ratio(reg_every): 8 | if reg_every == 0: 9 | reg_ratio = 1 10 | else: 11 | reg_ratio = float(reg_every) / float(reg_every + 1) 12 | 13 | return reg_ratio 14 | 15 | def apply_gradients(loss, optim, mixed_flag=False, scaler_x=None, scaler_min=None): 16 | optim.zero_grad() 17 | 18 | if mixed_flag: 19 | scaler_x.scale(loss).backward() 20 | scaler_x.step(optim) 21 | scaler_x.update() 22 | if scaler_x.get_scale() < scaler_min: 23 | scaler_x.update(16384.0) 24 | else: 25 | loss.backward() 26 | optim.step() 27 | 28 | def moving_average(ema_model, origin_model, decay=0.999): 29 | # model1 = ema 30 | # model2 = origin 31 | 32 | with torch.no_grad(): 33 | ema_param = dict(ema_model.named_parameters()) 34 | origin_param = dict(origin_model.named_parameters()) 35 | 36 | for k in ema_param.keys(): 37 | ema_param[k].data.mul_(decay).add_(origin_param[k].data, alpha=1 - decay) 38 | # ema_param[k].data = decay * ema_param[k].data + (1 - decay) * origin_param[k].data 39 | 40 | def d_hinge_loss(real_pred, fake_pred, fake_pred2): 41 | real_loss = torch.mean(F.relu(1.0 - real_pred)) 42 | fake_loss = torch.mean(F.relu(1.0 + fake_pred)) 43 | if fake_pred2 is None: 44 | d_loss = real_loss + fake_loss 45 | else: 46 | fake_loss2 = torch.mean(F.relu(1.0 + fake_pred2)) 47 | fake_loss = (fake_loss + fake_loss2) * 0.5 48 | d_loss = real_loss + fake_loss 49 | return d_loss 50 | 51 | def g_hinge_loss(fake_pred): 52 | g_loss = -torch.mean(fake_pred) 53 | return g_loss 54 | 55 | 56 | def d_logistic_loss(real_pred, fake_pred, fake_pred2): 57 | real_loss = F.softplus(-real_pred) 58 | fake_loss = F.softplus(fake_pred) 59 | 60 | if fake_pred2 is None: 61 | return real_loss.mean() + fake_loss.mean() 62 | else: 63 | fake_loss2 = F.softplus(fake_pred2) 64 | return real_loss.mean() + (fake_loss.mean() + fake_loss2.mean()) * 0.5 65 | 66 | def d_r1_loss(logits, real_img, text_embed=None): 67 | if text_embed is None: 68 | grad_real = torch.autograd.grad( 69 | outputs=logits.sum(), 70 | inputs=real_img, 71 | create_graph=True, 72 | )[0] 73 | grad_penalty = (grad_real ** 2).reshape(grad_real.shape[0], -1).sum(1).mean() 74 | 75 | else: 76 | grads = torch.autograd.grad( 77 | outputs=logits.sum(), 78 | inputs=(real_img, text_embed), 79 | create_graph=True, 80 | ) 81 | grad0 = grads[0].view(grads[0].size(0), -1) 82 | grad1 = grads[1].view(grads[1].size(0), -1) 83 | grad = torch.cat((grad0, grad1), dim=1) 84 | # norm은 torch.sqrt((grad ** 2).sum(1)) 임 85 | grad_penalty = (grad ** 2).sum(1).mean() 86 | 87 | return grad_penalty 88 | 89 | def g_nonsaturating_loss(fake_pred): 90 | loss = F.softplus(-fake_pred).mean() 91 | 92 | return loss 93 | 94 | def 
d_adv_loss(real_pred, fake_pred, fake_pred2=None, gan_type='gan'): 95 | if gan_type == 'hinge': 96 | loss = d_hinge_loss(real_pred, fake_pred, fake_pred2) 97 | else: 98 | loss = d_logistic_loss(real_pred, fake_pred, fake_pred2) 99 | 100 | return loss 101 | 102 | def g_adv_loss(fake_pred, gan_type='gan'): 103 | if gan_type == 'hinge': 104 | loss = g_hinge_loss(fake_pred) 105 | else: 106 | loss = g_nonsaturating_loss(fake_pred) 107 | 108 | return loss 109 | 110 | 111 | def predict_loss(predictor, img_feature, text_feature, negtive): 112 | output = predictor(img_feature, text_feature) 113 | err = hinge_loss(output, negtive) 114 | return output,err 115 | 116 | def hinge_loss(output, negtive): 117 | if negtive==False: 118 | err = torch.mean(F.relu(1. - output)) 119 | else: 120 | err = torch.mean(F.relu(1. + output)) 121 | return err 122 | 123 | def MA_GP_FP32(img, sent, out): 124 | grads = torch.autograd.grad(outputs=out, 125 | inputs=(img, sent), 126 | grad_outputs=torch.ones(out.size()).cuda(), 127 | retain_graph=True, 128 | create_graph=True, 129 | only_inputs=True) 130 | grad0 = grads[0].view(grads[0].size(0), -1) 131 | grad1 = grads[1].view(grads[1].size(0), -1) 132 | grad = torch.cat((grad0,grad1),dim=1) 133 | grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) 134 | d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6) 135 | return d_loss_gp 136 | 137 | def MA_GP_MP(img, sent, out, scaler): 138 | grads = torch.autograd.grad(outputs=scaler.scale(out), 139 | inputs=(img, sent), 140 | grad_outputs=torch.ones_like(out), 141 | retain_graph=True, 142 | create_graph=True, 143 | only_inputs=True) 144 | inv_scale = 1./(scaler.get_scale()+float("1e-8")) 145 | #inv_scale = 1./scaler.get_scale() 146 | grads = [grad * inv_scale for grad in grads] 147 | with torch.cuda.amp.autocast(): 148 | grad0 = grads[0].view(grads[0].size(0), -1) 149 | grad1 = grads[1].view(grads[1].size(0), -1) 150 | grad = torch.cat((grad0,grad1),dim=1) 151 | grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1)) 152 | d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6) 153 | return d_loss_gp 154 | 155 | # clip loss 156 | def clip_image_process(x): 157 | def denormalize(x): 158 | # [-1, 1] ~ [0, 255] 159 | x = ((x + 1) / 2 * 255).clamp(0, 255).to(torch.uint8) 160 | 161 | return x 162 | 163 | def resize(x): 164 | x = transforms.Resize(size=[224, 224], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)(x) 165 | return x 166 | 167 | def zero_to_one(x): 168 | x = x.float() / 255.0 169 | return x 170 | 171 | def norm_mean_std(x): 172 | mean = [0.48145466, 0.4578275, 0.40821073] 173 | std = [0.26862954, 0.26130258, 0.27577711] 174 | x = transforms.Normalize(mean=mean, std=std, inplace=True)(x) 175 | return x 176 | 177 | 178 | x = denormalize(x) 179 | x = resize(x) 180 | x = zero_to_one(x) 181 | x = norm_mean_std(x) 182 | 183 | return x 184 | 185 | def cosine_sim_loss(image_feat, text_feat): 186 | image_feat = image_feat / image_feat.norm(p=2, dim=-1, keepdim=True) 187 | text_feat = text_feat / text_feat.norm(p=2, dim=-1, keepdim=True) 188 | 189 | loss = -F.cosine_similarity(image_feat, text_feat).mean() 190 | return loss 191 | 192 | def clip_score(clip_model, image, text): 193 | txt_features = clip_model.get_text_features(text) 194 | 195 | processed_image = clip_image_process(image) 196 | img_features = clip_model.get_image_features(processed_image) 197 | 198 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) 199 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True) 200 | 201 
| # score = 100 * (img_features * txt_features).sum(axis=-1) 202 | # score = torch.mean(score) 203 | 204 | score = -F.cosine_similarity(img_features, txt_features).mean() 205 | 206 | return score 207 | 208 | def clip_image_score(clip_model, image1, image2): 209 | processed_image1 = clip_image_process(image1) 210 | processed_image2 = clip_image_process(image2) 211 | 212 | img_features1 = clip_model.get_image_features(processed_image1) 213 | img_features2 = clip_model.get_image_features(processed_image2) 214 | 215 | img_features1 = img_features1 / img_features1.norm(p=2, dim=-1, keepdim=True) 216 | img_features2 = img_features2 / img_features2.norm(p=2, dim=-1, keepdim=True) 217 | 218 | # score = 100 * (img_features1 * img_features2).sum(axis=-1) 219 | # score = torch.mean(score) 220 | 221 | score = -F.cosine_similarity(img_features1, img_features2).mean() 222 | 223 | return score 224 | 225 | def contrastive_loss(logits, dim) : 226 | neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim)) 227 | return -neg_ce.mean() 228 | 229 | def clip_score_(clip_model, image, text): 230 | txt_features = clip_model.get_text_features(text) 231 | 232 | processed_image = clip_image_process(image) 233 | img_features = clip_model.get_image_features(processed_image) 234 | 235 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) 236 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True) 237 | 238 | # cosine similarity as logits 239 | logit_scale = clip_model.logit_scale.exp() 240 | similarity = torch.matmul(txt_features, img_features.t()) * logit_scale 241 | 242 | caption_loss = contrastive_loss(similarity, dim=0) 243 | image_loss = contrastive_loss(similarity, dim=1) 244 | 245 | return (caption_loss + image_loss) / 2.0 246 | 247 | def clip_image_score_(clip_model, image1, image2): 248 | processed_image1 = clip_image_process(image1) 249 | processed_image2 = clip_image_process(image2) 250 | 251 | img_features1 = clip_model.get_image_features(processed_image1) 252 | img_features2 = clip_model.get_image_features(processed_image2) 253 | 254 | img_features1 = img_features1 / img_features1.norm(p=2, dim=-1, keepdim=True) 255 | img_features2 = img_features2 / img_features2.norm(p=2, dim=-1, keepdim=True) 256 | 257 | # cosine similarity as logits 258 | logit_scale = clip_model.logit_scale.exp() 259 | similarity = torch.matmul(img_features1, img_features2.t()) * logit_scale 260 | 261 | caption_loss = contrastive_loss(similarity, dim=0) 262 | image_loss = contrastive_loss(similarity, dim=1) 263 | 264 | return (caption_loss + image_loss) / 2.0 265 | 266 | def convert_to_billion_and_million(value, decimal_places=2): 267 | billion = round(value / 1_000_000_000, decimal_places) 268 | million = round(value / 1_000_000, decimal_places) 269 | 270 | return billion, million -------------------------------------------------------------------------------- /src/GALIP/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from torchvision import transforms 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import os, re 8 | from glob import glob 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel 11 | import torch.multiprocessing as torch_multiprocessing 12 | 13 | import json 14 | import requests 15 | import traceback 16 | 17 | from transformers import CLIPTokenizer, CLIPModel 18 | import torch.nn as nn 19 | 20 | import numpy as np 21 | import 
torch.nn.functional as F 22 | import random 23 | import clip 24 | 25 | class ImageTextDataset(Dataset): 26 | def __init__(self, dataset_path, img_size, imagenet_normalization=False, max_length=77): 27 | self.image_samples, self.text_samples = self.listdir(dataset_path) 28 | self.max_length = max_length 29 | 30 | transform_list = image_preprocess(img_size, imagenet_normalization) 31 | self.transform = transforms.Compose(transform_list) 32 | 33 | # self.tokenizer, self.clip = FrozenNetwork(max_length=max_length).load() 34 | 35 | def listdir(self, dir_path): 36 | img_extensions = ['png', 'jpg', 'jpeg', 'JPG'] 37 | image_list = [] 38 | for ext in img_extensions: 39 | image_list += glob(os.path.join(dir_path, 'image', '*.' + ext)) 40 | image_list.sort() 41 | 42 | txt_extensions = ['txt'] 43 | text_list = [] 44 | for ext in txt_extensions: 45 | text_list += glob(os.path.join(dir_path, 'text', '*.' + ext)) 46 | text_list.sort() 47 | 48 | return image_list, text_list 49 | 50 | def __getitem__(self, index): 51 | image_path, text_path = self.image_samples[index], self.text_samples[index] 52 | img = Image.open(image_path).convert('RGB') 53 | txt = text_read(text_path) 54 | 55 | img = self.transform(img) 56 | 57 | # batch_encoding = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt") 58 | # tokens = batch_encoding["input_ids"] # [1, 77] 59 | tokens = clip.tokenize(txt, truncate=True) 60 | tokens = torch.squeeze(tokens) 61 | # tokens = tokens.to(self.clip_text_encoder.device) 62 | # outputs = self.clip_text_encoder(input_ids=tokens) 63 | # txt_embed = outputs.last_hidden_state # [77, 768] 64 | 65 | return img, tokens, txt 66 | 67 | def __len__(self): 68 | return len(self.image_samples) 69 | 70 | class ImageDataset(Dataset): 71 | def __init__(self, dataset_path, img_size, imagenet_normalization=False): 72 | self.image_samples = self.listdir(dataset_path) 73 | 74 | transform_list = image_preprocess(img_size, imagenet_normalization) 75 | self.transform = transforms.Compose(transform_list) 76 | 77 | def listdir(self, dir_path): 78 | img_extensions = ['png', 'jpg', 'jpeg', 'JPG'] 79 | image_list = [] 80 | for ext in img_extensions: 81 | image_list += glob(os.path.join(dir_path, 'image', '*.' 
+ ext)) 82 | image_list.sort() 83 | 84 | return image_list 85 | 86 | def __getitem__(self, index): 87 | image_path = self.image_samples[index] 88 | img = Image.open(image_path).convert('RGB') 89 | 90 | img = self.transform(img) 91 | 92 | return img 93 | 94 | def __len__(self): 95 | return len(self.image_samples) 96 | 97 | class FrozenNetwork(nn.Module): 98 | """Load Clip encoder (for text), SD-Autoencoder (for image)""" 99 | # https://github.com/baofff/U-ViT/blob/f0f35a9e710688ec669ae7154c490a8053f3139f/libs/clip.py 100 | def __init__(self, autoencoder_version="runwayml/stable-diffusion-v1-5", clip_version="openai/clip-vit-large-patch14", max_length=77): 101 | super().__init__() 102 | self.max_length = max_length 103 | 104 | self.tokenizer = CLIPTokenizer.from_pretrained(clip_version) 105 | self.clip = CLIPModel.from_pretrained(clip_version) 106 | 107 | self.freeze() 108 | 109 | def freeze(self): 110 | self.clip.eval() 111 | 112 | def load(self): 113 | return self.tokenizer, self.clip 114 | 115 | 116 | def image_preprocess(img_size, imagenet_normalization=False): 117 | # interpolation=transforms.InterpolationMode.BICUBIC, antialias=True 118 | if imagenet_normalization: 119 | mean = [0.485, 0.456, 0.406] 120 | std = [0.229, 0.224, 0.225] 121 | transform_list = [ 122 | transforms.Resize(size=[img_size, img_size], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), # [h, w] 123 | transforms.ToTensor(), # [0, 255] -> [0, 1] # [c, h, w] 124 | transforms.Normalize(mean=mean, std=std, inplace=True), # [0, 1] -> [-1, 1] 125 | ] 126 | else: 127 | transform_list = [ 128 | transforms.Resize(size=[img_size, img_size], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), 129 | transforms.ToTensor(), # [0, 255] -> [0, 1] 130 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), # [0, 1] -> [-1, 1] 131 | ] 132 | 133 | return transform_list 134 | 135 | def text_read(text_path): 136 | with open(text_path, 'r') as f: 137 | x = f.readlines() 138 | 139 | t = [text.strip() for text in x] # remove \n 140 | 141 | t_sample = random.choice(t) 142 | 143 | return t_sample 144 | 145 | 146 | def check_folder(log_dir): 147 | if not os.path.exists(log_dir): 148 | os.makedirs(log_dir) 149 | return log_dir 150 | 151 | 152 | def str2bool(x): 153 | return x.lower() in ('true') 154 | 155 | 156 | def multi_gpu_run(ddp_fn, args): # in main 157 | # ddp_fn = train_fn 158 | world_size = torch.cuda.device_count() # ngpus 159 | torch_multiprocessing.spawn(fn=ddp_fn, args=(args, world_size), nprocs=world_size, join=True) 160 | 161 | 162 | def build_init_procss(rank, world_size, device): # in build 163 | os.environ["MASTER_ADDR"] = "127.0.0.1" # localhost 164 | os.environ["MASTER_PORT"] = "12355" 165 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 166 | synchronize() 167 | torch.cuda.set_device(device) 168 | 169 | 170 | def distributed_sampler(dataset, rank, num_replicas, shuffle): 171 | return torch.utils.data.distributed.DistributedSampler(dataset, rank=rank, num_replicas=num_replicas, shuffle=shuffle) 172 | # return torch.utils.data.RandomSampler(dataset) 173 | 174 | 175 | def infinite_iterator(loader): 176 | while True: 177 | for batch in loader: 178 | yield batch 179 | 180 | def find_latest_ckpt(folder): 181 | files = [] 182 | for fname in os.listdir(folder): 183 | s = re.findall(r'\d+', fname) 184 | if len(s) == 1: 185 | files.append((int(s[0]), fname)) 186 | if files: 187 | file_name = max(files)[1] 188 | index = os.path.splitext(file_name)[0] 189 
| return file_name, index 190 | else: 191 | return None, 0 192 | 193 | 194 | def broadcast_params(model): 195 | params = model.parameters() 196 | for param in params: 197 | dist.broadcast(param.data, src=0) 198 | dist.barrier() 199 | torch.cuda.synchronize() 200 | 201 | 202 | def dataparallel_and_sync(model, local_rank, find_unused_parameters=False): 203 | # DistributedDataParallel 204 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=find_unused_parameters) 205 | 206 | # broadcast 207 | broadcast_params(model) 208 | 209 | model = model.module 210 | 211 | return model 212 | 213 | def cleanup(): 214 | dist.destroy_process_group() 215 | 216 | def get_rank(): 217 | if not dist.is_available(): 218 | return 0 219 | 220 | if not dist.is_initialized(): 221 | return 0 222 | 223 | return dist.get_rank() 224 | 225 | def get_world_size(): 226 | if not dist.is_available(): 227 | return 1 228 | 229 | if not dist.is_initialized(): 230 | return 1 231 | 232 | return dist.get_world_size() 233 | 234 | def synchronize(): 235 | if not dist.is_available(): 236 | return 237 | 238 | if not dist.is_initialized(): 239 | return 240 | 241 | world_size = dist.get_world_size() 242 | 243 | if world_size == 1: 244 | return 245 | 246 | dist.barrier() 247 | 248 | def reduce_loss_dict(loss_dict): 249 | world_size = get_world_size() 250 | 251 | if world_size < 2: 252 | return loss_dict 253 | 254 | with torch.no_grad(): 255 | keys = [] 256 | losses = [] 257 | 258 | for k in sorted(loss_dict.keys()): 259 | keys.append(k) 260 | losses.append(loss_dict[k]) 261 | 262 | losses = torch.stack(losses, 0) 263 | dist.reduce(losses, dst=0) 264 | 265 | if dist.get_rank() == 0: 266 | losses /= world_size 267 | 268 | reduced_losses = {k: v.mean().item() for k, v in zip(keys, losses)} 269 | 270 | return reduced_losses 271 | 272 | 273 | def dict_to_numpy(x_dict, python_value=False): 274 | losses_numpy = {} 275 | for k,v in x_dict.items(): 276 | losses_numpy[k] = tensor_to_numpy(v, python_value=python_value) 277 | 278 | return losses_numpy 279 | 280 | def tensor_to_numpy(x, python_value=False): 281 | if isinstance(x, torch.Tensor): 282 | if python_value: 283 | return x.detach().cpu().numpy().tolist() 284 | else: 285 | return x.detach().cpu().numpy() 286 | else: 287 | return x 288 | 289 | def get_val(x): 290 | x_val = x.mean().item() 291 | 292 | return x_val 293 | 294 | def requires_grad(model, flag=True): 295 | for p in model.parameters(): 296 | p.requires_grad = flag 297 | 298 | def count_parameters(model): 299 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 300 | -------------------------------------------------------------------------------- /src/ddpm_ddim/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/.DS_Store -------------------------------------------------------------------------------- /src/ddpm_ddim/dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/.DS_Store -------------------------------------------------------------------------------- /src/ddpm_ddim/dataset/cat/flickr_cat_000008.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000008.png -------------------------------------------------------------------------------- /src/ddpm_ddim/dataset/cat/flickr_cat_000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000011.png -------------------------------------------------------------------------------- /src/ddpm_ddim/dataset/cat/flickr_cat_000016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000016.png -------------------------------------------------------------------------------- /src/ddpm_ddim/dataset/cat/flickr_cat_000056.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000056.png -------------------------------------------------------------------------------- /src/ddpm_ddim/dataset/cat/flickr_cat_000076.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000076.png -------------------------------------------------------------------------------- /src/ddpm_ddim/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from matplotlib import pyplot as plt 5 | from tqdm import tqdm 6 | from torch import optim 7 | from utils import * 8 | from modules import UNet, linear_beta_schedule, cosine_beta_schedule 9 | import logging 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torch.nn.functional as F 12 | 13 | logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S") 14 | 15 | 16 | class Diffusion: 17 | def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, objective='ddpm', schedule='linear', device="cuda"): 18 | self.noise_steps = noise_steps 19 | self.beta_start = beta_start 20 | self.beta_end = beta_end 21 | self.img_size = img_size 22 | self.device = device 23 | 24 | self.objective = objective 25 | 26 | self.beta = self.prepare_noise_schedule(schedule, beta_start, beta_end).to(device) 27 | 28 | self.alpha = 1. 
- self.beta 29 | self.alpha_hat = torch.cumprod(self.alpha, dim=0) 30 | 31 | def prepare_noise_schedule(self, schedule, beta_start, beta_end): 32 | if schedule == 'linear': 33 | return linear_beta_schedule(self.noise_steps, beta_start, beta_end) 34 | else: 35 | return cosine_beta_schedule(self.noise_steps) 36 | 37 | def noise_images(self, x, t): 38 | sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None] 39 | sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None] 40 | z = torch.randn_like(x) 41 | return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * z, z 42 | 43 | def sample_timesteps(self, n): 44 | t = torch.randint(low=1, high=self.noise_steps, size=(n,)) 45 | return t 46 | 47 | def tensor_to_image(self, x): 48 | x = (x.clamp(-1, 1) + 1) / 2 49 | x = (x * 255).type(torch.uint8) 50 | return x 51 | 52 | def sample(self, model, n): 53 | # reverse process 54 | logging.info(f"Sampling {n} new images....") 55 | model.eval() 56 | with torch.no_grad(): 57 | x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device) 58 | for i in tqdm(reversed(range(1, self.noise_steps))): 59 | t = (torch.ones(n, dtype=torch.long) * i).to(self.device) 60 | 61 | alpha = self.alpha[t][:, None, None, None] 62 | beta = self.beta[t][:, None, None, None] 63 | alpha_hat = self.alpha_hat[t][:, None, None, None] 64 | alpha_hat_prev = self.alpha_hat[t-1][:, None, None, None] 65 | beta_tilde = beta * (1 - alpha_hat_prev) / (1 - alpha_hat) # similar to beta 66 | 67 | predicted_noise = model(x, t) 68 | noise = torch.randn_like(x) 69 | 70 | if self.objective == 'ddpm': 71 | predict_x0 = 0 72 | direction_point = 1 / torch.sqrt(alpha) * (x - (beta / (torch.sqrt(1 - alpha_hat))) * predicted_noise) 73 | random_noise = torch.sqrt(beta_tilde) * noise 74 | 75 | x = predict_x0 + direction_point + random_noise 76 | else: 77 | predict_x0 = torch.sqrt(alpha_hat_prev) * (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat) 78 | direction_point = torch.sqrt(1 - alpha_hat_prev) * predicted_noise 79 | random_noise = 0 80 | 81 | x = predict_x0 + direction_point + random_noise 82 | 83 | model.train() 84 | return torch.clamp(x, -1.0, 1.0) 85 | 86 | 87 | def train(args): 88 | setup_logging(args.run_name) 89 | device = args.device 90 | dataloader = get_data(args) 91 | model = UNet(device=device).to(device) 92 | optimizer = optim.AdamW(model.parameters(), lr=args.lr) 93 | mse = nn.MSELoss() 94 | diffusion = Diffusion(img_size=args.image_size, device=device) 95 | logger = SummaryWriter(os.path.join("logs", args.run_name)) 96 | l = len(dataloader) 97 | 98 | for epoch in range(args.epochs): 99 | logging.info(f"Starting epoch {epoch}:") 100 | pbar = tqdm(dataloader) 101 | for i, images in enumerate(pbar): 102 | images = images.to(device) 103 | t = diffusion.sample_timesteps(images.shape[0]).to(device) 104 | x_t, noise = diffusion.noise_images(images, t) 105 | predicted_noise = model(x_t, t) 106 | loss = mse(noise, predicted_noise) 107 | 108 | optimizer.zero_grad() 109 | loss.backward() 110 | optimizer.step() 111 | 112 | pbar.set_postfix(MSE=loss.item()) 113 | logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i) 114 | 115 | sampled_images = diffusion.sample(model, n=images.shape[0]) 116 | save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.png")) 117 | torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt")) 118 | 119 | 120 | def launch(): 121 | import argparse 122 | parser = argparse.ArgumentParser() 123 | args = 
parser.parse_args() 124 | args.epochs = 100 125 | args.batch_size = 16 126 | args.image_size = 64 127 | args.objective = 'ddpm' 128 | args.schedule = 'linear' 129 | args.dataset_path = "../dataset/cat" 130 | args.device = "cuda" 131 | args.lr = 3e-4 132 | 133 | args.run_name = "diffusion_{}_{}".format(args.objective, args.schedule) 134 | train(args) 135 | 136 | 137 | if __name__ == '__main__': 138 | launch() 139 | -------------------------------------------------------------------------------- /src/ddpm_ddim/main_template.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from matplotlib import pyplot as plt 5 | from tqdm import tqdm 6 | from torch import optim 7 | from utils import * 8 | from modules import UNet, linear_beta_schedule, cosine_beta_schedule 9 | import logging 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torch.nn.functional as F 12 | 13 | logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S") 14 | 15 | 16 | class Diffusion: 17 | def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, objective='ddpm', schedule='linear', device="cuda"): 18 | self.noise_steps = noise_steps 19 | self.beta_start = beta_start 20 | self.beta_end = beta_end 21 | self.img_size = img_size 22 | self.device = device 23 | 24 | self.objective = objective 25 | 26 | self.beta = self.prepare_noise_schedule(schedule, beta_start, beta_end).to(device) 27 | 28 | """ 29 | Step 1. 30 | 31 | self.alpha = ? 32 | self.alpha_hat = ? 33 | 34 | """ 35 | 36 | 37 | def prepare_noise_schedule(self, schedule, beta_start, beta_end): 38 | if schedule == 'linear': 39 | return linear_beta_schedule(self.noise_steps, beta_start, beta_end) 40 | else: 41 | return cosine_beta_schedule(self.noise_steps) 42 | 43 | def sample_timesteps(self, n): 44 | """ 45 | Step 2. 46 | n개의 랜덤한 timestep을 샘플링 하세요. range = [1, self.noise_steps] 47 | 48 | :param n: int 49 | :return: [n, ] shape을 갖고있을것입니다. 50 | 51 | 주의사항: timestep이니까, 값은 int형이어야 합니다. 52 | 53 | """ 54 | return 55 | 56 | def noise_images(self, x, t): 57 | """ 58 | Step 3. 59 | forward process를 작성하세요. 60 | -> 이미지에 noise를 입히는 과정입니다. 61 | 62 | return은 노이즈를 입힌 이미지와, 입혔던 노이즈를 리턴하세요 !! 총 2개입니다. 63 | 64 | :param x: [n, 3, img_size, img_size] 65 | :param t: [n, ] 66 | :return: [n, 3, img_size, img_size], [n, 3, img_size, img_size] 67 | 68 | """ 69 | return 70 | 71 | 72 | def sample(self, model, n): 73 | """ 74 | Step 5. 마지막! 75 | reverse process를 완성하세요. 76 | 77 | :param model: Unet 78 | :param n: batch_size 79 | :return: x: [n, 3, img_size, img_size] 80 | """ 81 | logging.info(f"Sampling {n} new images....") 82 | model.eval() 83 | with torch.no_grad(): 84 | """ 85 | (1) T스텝에서 부터 denoise하는것이기때문에, 가우시안 noise를 하나 만드세요. 86 | (2) T (self.noise_steps)부터 denoise하는 구문을 만드세요. 87 | hint: T, T-1, T-2, ... , 3, 2, 1 이런식으로 t가 나와야겠죠 ? 88 | (3) t에 해당하는 alpha_t, beta_t, alpha_hat_t, alpha_hat_(t-1), beta_tilde를 만드세요. 89 | 90 | (4) (1)의 noise와 (2)의 t를 모델에 넣어서, noise를 predict하세요. 91 | (5) predict한 noise를 가지고, ddpm과 ddim sampling를 작성하세요. 
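        Reference: the completed main.py in this folder implements these updates, so you
        can check your answer against them.
            DDPM:  x_{t-1} = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_hat_t) * eps_theta)
                             + sqrt(beta_tilde_t) * z,  with z ~ N(0, I) and
                   beta_tilde_t = beta_t * (1 - alpha_hat_{t-1}) / (1 - alpha_hat_t)
            DDIM (eta = 0):
                   x0_pred = (x_t - sqrt(1 - alpha_hat_t) * eps_theta) / sqrt(alpha_hat_t)
                   x_{t-1} = sqrt(alpha_hat_{t-1}) * x0_pred + sqrt(1 - alpha_hat_{t-1}) * eps_theta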
92 | 93 | """ 94 | 95 | model.train() 96 | return torch.clamp(x, -1.0, 1.0) 97 | 98 | 99 | def train(args): 100 | setup_logging(args.run_name) 101 | device = args.device 102 | dataloader = get_data(args) 103 | model = UNet(device=device).to(device) 104 | optimizer = optim.AdamW(model.parameters(), lr=args.lr) 105 | diffusion = Diffusion(img_size=args.image_size, device=device) 106 | logger = SummaryWriter(os.path.join("logs", args.run_name)) 107 | l = len(dataloader) 108 | 109 | for epoch in range(args.epochs): 110 | logging.info(f"Starting epoch {epoch}:") 111 | pbar = tqdm(dataloader) 112 | for i, images in enumerate(pbar): 113 | images = images.to(device) 114 | """ 115 | Step 4. 116 | 학습코드를 작성해보세요. 117 | 다음 hint를 참고하여 작성하면됩니다. 118 | 119 | hint: 120 | (1) timestep을 샘플링 하세요. 121 | (2) 해당 timestep t에 대응되는 노이즈 입힌 이미지를 만드세요. 122 | (3) 모델에 넣어서, 노이즈를 predict 하세요. 123 | (4) 적절한 loss를 선택하세요. (L1 or L2) 124 | """ 125 | 126 | optimizer.zero_grad() 127 | loss.backward() 128 | optimizer.step() 129 | 130 | pbar.set_postfix(Loss=loss.item()) 131 | logger.add_scalar("diffusion loss", loss.item(), global_step=epoch * l + i) 132 | 133 | sampled_images = diffusion.sample(model, n=images.shape[0]) 134 | save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.png")) 135 | torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt")) 136 | 137 | 138 | def launch(): 139 | import argparse 140 | parser = argparse.ArgumentParser() 141 | args = parser.parse_args() 142 | args.epochs = 100 143 | args.batch_size = 16 144 | args.image_size = 64 145 | args.objective = 'ddpm' 146 | args.schedule = 'linear' 147 | args.dataset_path = "../dataset/cat" 148 | args.device = "cpu" 149 | args.lr = 3e-4 150 | 151 | args.run_name = "diffusion_{}_{}".format(args.objective, args.schedule) 152 | train(args) 153 | 154 | 155 | if __name__ == '__main__': 156 | launch() 157 | -------------------------------------------------------------------------------- /src/ddpm_ddim/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | def linear_beta_schedule(timesteps, beta_start, beta_end): 7 | scale = 1000 / timesteps 8 | beta_start = scale * beta_start 9 | beta_end = scale * beta_end 10 | 11 | return torch.linspace(beta_start, beta_end, timesteps) 12 | 13 | def cosine_beta_schedule(timesteps, s = 0.008): 14 | """ 15 | cosine schedule 16 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 17 | """ 18 | steps = timesteps + 1 19 | x = torch.linspace(0, timesteps, steps) 20 | 21 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 22 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 23 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 24 | 25 | return torch.clamp(betas, 0, 0.999) 26 | 27 | class EMA: 28 | def __init__(self, beta): 29 | super().__init__() 30 | self.beta = beta 31 | self.step = 0 32 | 33 | def update_model_average(self, ma_model, current_model): 34 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 35 | old_weight, up_weight = ma_params.data, current_params.data 36 | ma_params.data = self.update_average(old_weight, up_weight) 37 | 38 | def update_average(self, old, new): 39 | if old is None: 40 | return new 41 | return old * self.beta + (1 - self.beta) * new 42 | 43 | def step_ema(self, ema_model, model, step_start_ema=2000): 44 | if self.step < step_start_ema: 45 | 
self.reset_parameters(ema_model, model) 46 | self.step += 1 47 | return 48 | self.update_model_average(ema_model, model) 49 | self.step += 1 50 | 51 | def reset_parameters(self, ema_model, model): 52 | ema_model.load_state_dict(model.state_dict()) 53 | 54 | 55 | class SelfAttention(nn.Module): 56 | def __init__(self, channels, size): 57 | super(SelfAttention, self).__init__() 58 | self.channels = channels 59 | self.size = size 60 | self.mha = nn.MultiheadAttention(channels, 4, batch_first=True) 61 | self.ln = nn.LayerNorm([channels]) 62 | self.ff_self = nn.Sequential( 63 | nn.LayerNorm([channels]), 64 | nn.Linear(channels, channels), 65 | nn.GELU(), 66 | nn.Linear(channels, channels), 67 | ) 68 | 69 | def forward(self, x): 70 | x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2) 71 | x_ln = self.ln(x) 72 | attention_value, _ = self.mha(x_ln, x_ln, x_ln) 73 | attention_value = attention_value + x 74 | attention_value = self.ff_self(attention_value) + attention_value 75 | return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size) 76 | 77 | 78 | class DoubleConv(nn.Module): 79 | def __init__(self, in_channels, out_channels, mid_channels=None, residual=False): 80 | super().__init__() 81 | self.residual = residual 82 | if not mid_channels: 83 | mid_channels = out_channels 84 | self.double_conv = nn.Sequential( 85 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 86 | nn.GroupNorm(1, mid_channels), 87 | nn.GELU(), 88 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 89 | nn.GroupNorm(1, out_channels), 90 | ) 91 | 92 | def forward(self, x): 93 | if self.residual: 94 | return F.gelu(x + self.double_conv(x)) 95 | else: 96 | return self.double_conv(x) 97 | 98 | 99 | class Down(nn.Module): 100 | def __init__(self, in_channels, out_channels, emb_dim=256): 101 | super().__init__() 102 | self.maxpool_conv = nn.Sequential( 103 | nn.MaxPool2d(2), 104 | DoubleConv(in_channels, in_channels, residual=True), 105 | DoubleConv(in_channels, out_channels), 106 | ) 107 | 108 | self.emb_layer = nn.Sequential( 109 | nn.SiLU(), 110 | nn.Linear( 111 | emb_dim, 112 | out_channels 113 | ), 114 | ) 115 | 116 | def forward(self, x, t): 117 | x = self.maxpool_conv(x) 118 | emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) 119 | return x + emb 120 | 121 | 122 | class Up(nn.Module): 123 | def __init__(self, in_channels, out_channels, emb_dim=256): 124 | super().__init__() 125 | 126 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 127 | self.conv = nn.Sequential( 128 | DoubleConv(in_channels, in_channels, residual=True), 129 | DoubleConv(in_channels, out_channels, in_channels // 2), 130 | ) 131 | 132 | self.emb_layer = nn.Sequential( 133 | nn.SiLU(), 134 | nn.Linear( 135 | emb_dim, 136 | out_channels 137 | ), 138 | ) 139 | 140 | def forward(self, x, skip_x, t): 141 | x = self.up(x) 142 | x = torch.cat([skip_x, x], dim=1) 143 | x = self.conv(x) 144 | emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1]) 145 | return x + emb 146 | 147 | 148 | class UNet(nn.Module): 149 | def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"): 150 | super().__init__() 151 | self.device = device 152 | self.time_dim = time_dim 153 | self.inc = DoubleConv(c_in, 64) 154 | self.down1 = Down(64, 128) 155 | self.sa1 = SelfAttention(128, 32) 156 | self.down2 = Down(128, 256) 157 | self.sa2 = SelfAttention(256, 16) 158 | self.down3 = Down(256, 256) 159 | self.sa3 
= SelfAttention(256, 8) 160 | 161 | self.bot1 = DoubleConv(256, 512) 162 | self.bot2 = DoubleConv(512, 512) 163 | self.bot3 = DoubleConv(512, 256) 164 | 165 | self.up1 = Up(512, 128) 166 | self.sa4 = SelfAttention(128, 16) 167 | self.up2 = Up(256, 64) 168 | self.sa5 = SelfAttention(64, 32) 169 | self.up3 = Up(128, 64) 170 | self.sa6 = SelfAttention(64, 64) 171 | self.outc = nn.Conv2d(64, c_out, kernel_size=1) 172 | 173 | def pos_encoding(self, t, channels): 174 | inv_freq = 1.0 / (10000** (torch.arange(0, channels, 2, device=self.device).float() / channels)) 175 | pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq) 176 | pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq) 177 | pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1) 178 | return pos_enc 179 | 180 | def forward(self, x, t): 181 | t = t.unsqueeze(-1).type(torch.float) 182 | t = self.pos_encoding(t, self.time_dim) 183 | 184 | x1 = self.inc(x) 185 | x2 = self.down1(x1, t) 186 | x2 = self.sa1(x2) 187 | x3 = self.down2(x2, t) 188 | x3 = self.sa2(x3) 189 | x4 = self.down3(x3, t) 190 | x4 = self.sa3(x4) 191 | 192 | x4 = self.bot1(x4) 193 | x4 = self.bot2(x4) 194 | x4 = self.bot3(x4) 195 | 196 | x = self.up1(x4, x3, t) 197 | x = self.sa4(x) 198 | x = self.up2(x, x2, t) 199 | x = self.sa5(x) 200 | x = self.up3(x, x1, t) 201 | x = self.sa6(x) 202 | output = self.outc(x) 203 | return output 204 | 205 | 206 | class UNet_conditional(nn.Module): 207 | def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"): 208 | super().__init__() 209 | self.device = device 210 | self.time_dim = time_dim 211 | self.inc = DoubleConv(c_in, 64) 212 | self.down1 = Down(64, 128) 213 | self.sa1 = SelfAttention(128, 32) 214 | self.down2 = Down(128, 256) 215 | self.sa2 = SelfAttention(256, 16) 216 | self.down3 = Down(256, 256) 217 | self.sa3 = SelfAttention(256, 8) 218 | 219 | self.bot1 = DoubleConv(256, 512) 220 | self.bot2 = DoubleConv(512, 512) 221 | self.bot3 = DoubleConv(512, 256) 222 | 223 | self.up1 = Up(512, 128) 224 | self.sa4 = SelfAttention(128, 16) 225 | self.up2 = Up(256, 64) 226 | self.sa5 = SelfAttention(64, 32) 227 | self.up3 = Up(128, 64) 228 | self.sa6 = SelfAttention(64, 64) 229 | self.outc = nn.Conv2d(64, c_out, kernel_size=1) 230 | 231 | if num_classes is not None: 232 | self.label_emb = nn.Embedding(num_classes, time_dim) 233 | 234 | def pos_encoding(self, t, channels): 235 | inv_freq = 1.0 / ( 236 | 10000 237 | ** (torch.arange(0, channels, 2, device=self.device).float() / channels) 238 | ) 239 | pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq) 240 | pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq) 241 | pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1) 242 | return pos_enc 243 | 244 | def forward(self, x, t, y): 245 | t = t.unsqueeze(-1).type(torch.float) 246 | t = self.pos_encoding(t, self.time_dim) 247 | 248 | if y is not None: 249 | t += self.label_emb(y) 250 | 251 | x1 = self.inc(x) 252 | x2 = self.down1(x1, t) 253 | x2 = self.sa1(x2) 254 | x3 = self.down2(x2, t) 255 | x3 = self.sa2(x3) 256 | x4 = self.down3(x3, t) 257 | x4 = self.sa3(x4) 258 | 259 | x4 = self.bot1(x4) 260 | x4 = self.bot2(x4) 261 | x4 = self.bot3(x4) 262 | 263 | x = self.up1(x4, x3, t) 264 | x = self.sa4(x) 265 | x = self.up2(x, x2, t) 266 | x = self.sa5(x) 267 | x = self.up3(x, x1, t) 268 | x = self.sa6(x) 269 | output = self.outc(x) 270 | return output 271 | -------------------------------------------------------------------------------- /src/ddpm_ddim/noise.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/noise.jpg -------------------------------------------------------------------------------- /src/ddpm_ddim/noise_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.utils import save_image 3 | from main import Diffusion 4 | from utils import get_data 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | args = parser.parse_args() 9 | args.batch_size = 1 # 5 10 | args.image_size = 64 11 | args.dataset_path = '../dataset/cat' 12 | 13 | dataloader = get_data(args) 14 | 15 | diff = Diffusion(device="cpu") 16 | 17 | image = next(iter(dataloader))[0] 18 | t = torch.Tensor([50, 100, 150, 200, 300, 600, 700, 999]).long() 19 | 20 | noised_image, _ = diff.noise_images(image, t) 21 | save_image(noised_image.add(1).mul(0.5), "noise.jpg") -------------------------------------------------------------------------------- /src/ddpm_ddim/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | from PIL import Image 5 | from matplotlib import pyplot as plt 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | from torch.utils.data import Dataset 9 | from glob import glob 10 | 11 | def plot_images(images): 12 | plt.figure(figsize=(32, 32)) 13 | plt.imshow(torch.cat([ 14 | torch.cat([i for i in images.cpu()], dim=-1), 15 | ], dim=-2).permute(1, 2, 0).cpu()) 16 | plt.show() 17 | 18 | 19 | def save_images(images, path): 20 | torchvision.utils.save_image(images, path, 21 | nrow=4, 22 | normalize=True, range=(-1, 1)) 23 | 24 | 25 | def get_data(args): 26 | dataset = ImageDataset(img_size=args.image_size, dataset_path=args.dataset_path) 27 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) 28 | return dataloader 29 | 30 | class ImageDataset(Dataset): 31 | def __init__(self, img_size, dataset_path): 32 | self.train_images = self.listdir(dataset_path) 33 | 34 | transform_list = [ 35 | transforms.Resize(size=[img_size, img_size]), 36 | transforms.RandomHorizontalFlip(p=0.5), 37 | transforms.ToTensor(), # [0, 255] -> [0, 1] 38 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), # [0, 1] -> [-1, 1] 39 | ] 40 | 41 | self.transform = transforms.Compose(transform_list) 42 | 43 | def listdir(self, dir_path): 44 | extensions = ['png', 'jpg', 'jpeg', 'JPG'] 45 | file_path = [] 46 | for ext in extensions: 47 | file_path += glob(os.path.join(dir_path, '*.' 
+ ext)) 48 | file_path.sort() 49 | return file_path 50 | 51 | def __getitem__(self, index): 52 | sample_path = self.train_images[index] 53 | img = Image.open(sample_path).convert('RGB') 54 | img = self.transform(img) 55 | 56 | 57 | return img 58 | 59 | def __len__(self): 60 | return len(self.train_images) 61 | 62 | def setup_logging(run_name): 63 | os.makedirs("models", exist_ok=True) 64 | os.makedirs("results", exist_ok=True) 65 | os.makedirs(os.path.join("models", run_name), exist_ok=True) 66 | os.makedirs(os.path.join("results", run_name), exist_ok=True) -------------------------------------------------------------------------------- /src/evaluation/clipscore.py: -------------------------------------------------------------------------------- 1 | from torchmetrics.multimodal import CLIPScore 2 | from transformers import CLIPTokenizer, CLIPTextModel, CLIPVisionModel, CLIPModel 3 | import torch 4 | from torchvision import transforms 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | 8 | def clip_image_process(x): 9 | def denormalize(x): 10 | # [-1, 1] ~ [0, 255] 11 | x = ((x + 1) / 2 * 255).clamp(0, 255).to(torch.uint8) 12 | 13 | return x 14 | 15 | def resize(x): 16 | x = transforms.Resize(size=[224, 224], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)(x) 17 | return x 18 | 19 | def zero_to_one(x): 20 | x = x.float() / 255.0 21 | return x 22 | 23 | def norm_mean_std(x): 24 | mean = [0.48145466, 0.4578275, 0.40821073] 25 | std = [0.26862954, 0.26130258, 0.27577711] 26 | x = transforms.Normalize(mean=mean, std=std, inplace=True)(x) 27 | return x 28 | 29 | # 만약 x가 [-1, 1] 이면, denorm을 해줍니다. 30 | # x = denormalize(x) 31 | x = resize(x) 32 | x = zero_to_one(x) 33 | x = norm_mean_std(x) 34 | 35 | return x 36 | 37 | def contrastive_loss(logits, dim) : 38 | neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim)) 39 | return -neg_ce.mean() 40 | 41 | def clip_contra_loss(img_features, txt_features, logit_scale): 42 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) 43 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True) 44 | 45 | # cosine similarity as logits 46 | logit_scale = logit_scale.exp() 47 | similarity = torch.matmul(txt_features, img_features.t()) * logit_scale 48 | 49 | caption_loss = contrastive_loss(similarity, dim=0) 50 | image_loss = contrastive_loss(similarity, dim=1) 51 | 52 | return (caption_loss + image_loss) / 2.0 # minimize 53 | 54 | def clip_score(img_features, txt_features): 55 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) 56 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True) 57 | 58 | # score = 100 * (img_features * txt_features).sum(axis=-1) 59 | # score = torch.mean(score) 60 | 61 | # 위와 같다. 62 | score = F.cosine_similarity(img_features, txt_features).mean() 63 | return score 64 | 65 | # library 66 | image = torch.randint(255, (2, 3, 224, 224)) 67 | text = ["a photo of a cat", "a photo of a cat"] 68 | version = 'openai/clip-vit-large-patch14' 69 | metric = CLIPScore(model_name_or_path=version) 70 | score = metric(image, text) 71 | print(score) 72 | 73 | """ 74 | Step 1. Model Init 75 | """ 76 | tokenizer = CLIPTokenizer.from_pretrained(version) 77 | clip_text_encoder = CLIPTextModel.from_pretrained(version) 78 | clip_image_encoder = CLIPVisionModel.from_pretrained(version) 79 | clip_model = CLIPModel.from_pretrained(version) 80 | 81 | """ 82 | Step 2. 
Text 83 | """ 84 | batch_encoding = tokenizer(text, truncation=True, max_length=77, padding="max_length", return_tensors="pt") 85 | # [input_ids, attention_mask] -> 둘다 [bs,77]의 shape을 갖고있습니다. 86 | # input_ids는 주어진 텍스트를 토크나이즈한것이고, mask는 어디까지만이 유효한 token인지 알려줍니다. 1=유효, 0=의미없음 87 | 88 | text_token = batch_encoding["input_ids"] 89 | t_embed = clip_text_encoder(text_token) # 이것은 clip_model.text_model(text_token)과 같다. 90 | # [last_hidden_state, pooler_output] -> [bs, 77, 768], [bs, 768] 91 | # last_hidden_state = word embedding 92 | # pooler_output = sentence embedding 93 | 94 | text_feature = clip_model.get_text_features(text_token) 95 | # pooler_output(sentence embedding) 에 Linear를 태운것 96 | # [bs, 768] 97 | 98 | """ 99 | Step 3. Image 100 | """ 101 | image = clip_image_process(image) 102 | 103 | i_embed = clip_image_encoder(image) # 이것은 clip_model.vision_model(image)과 같다. 104 | # [last_hidden_state, pooler_output] -> [bs, 256, 1024], [bs, 1024] 105 | 106 | image_feature = clip_model.get_image_features(image) 107 | # pooler_output에 Linear을 태운것 108 | 109 | print(clip_score(image_feature, text_feature)) 110 | print(clip_contra_loss(image_feature, text_feature, clip_model.logit_scale)) 111 | 112 | -------------------------------------------------------------------------------- /src/evaluation/data_loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from itertools import chain 3 | from PIL import Image 4 | from torch.utils import data 5 | from torchvision import transforms 6 | 7 | 8 | def listdir(dname): 9 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext)) 10 | for ext in ['png', 'jpg', 'jpeg', 'JPG']])) 11 | return fnames 12 | 13 | 14 | class DefaultDataset(data.Dataset): 15 | def __init__(self, root, transform=None): 16 | self.samples = listdir(root) 17 | self.samples.sort() 18 | self.transform = transform 19 | self.targets = None 20 | 21 | def __getitem__(self, index): 22 | fname = self.samples[index] 23 | img = Image.open(fname).convert('RGB') 24 | if self.transform is not None: 25 | img = self.transform(img) 26 | return img 27 | 28 | def __len__(self): 29 | return len(self.samples) 30 | 31 | 32 | 33 | def get_eval_loader(root, img_size=256, batch_size=32, 34 | imagenet_normalize=True, shuffle=True, 35 | num_workers=4, drop_last=False): 36 | print('Preparing DataLoader for the evaluation phase...') 37 | if imagenet_normalize: 38 | height, width = 299, 299 39 | mean = [0.485, 0.456, 0.406] 40 | std = [0.229, 0.224, 0.225] 41 | else: 42 | height, width = img_size, img_size 43 | mean = [0.5, 0.5, 0.5] 44 | std = [0.5, 0.5, 0.5] 45 | 46 | transform = transforms.Compose([ 47 | transforms.Resize([img_size, img_size]), 48 | transforms.Resize([height, width]), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean=mean, std=std) 51 | ]) 52 | 53 | dataset = DefaultDataset(root, transform=transform) 54 | return data.DataLoader(dataset=dataset, 55 | batch_size=batch_size, 56 | shuffle=shuffle, 57 | num_workers=num_workers, 58 | pin_memory=True, 59 | drop_last=drop_last) 60 | -------------------------------------------------------------------------------- /src/evaluation/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | from torchvision import models 7 | from scipy import linalg 8 | from data_loader import get_eval_loader 9 | from tqdm import tqdm 10 | 11 | 12 | class InceptionV3(nn.Module): 
    def __init__(self):
        super().__init__()
        inception = models.inception_v3(pretrained=True)
        self.block1 = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.block2 = nn.Sequential(
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.block3 = nn.Sequential(
            inception.Mixed_5b, inception.Mixed_5c,
            inception.Mixed_5d, inception.Mixed_6a,
            inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e)
        self.block4 = nn.Sequential(
            inception.Mixed_7a, inception.Mixed_7b,
            inception.Mixed_7c,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)))

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x.view(x.size(0), -1)


def frechet_distance(mu, cov, mu2, cov2):
    cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
    dist = np.sum((mu - mu2)**2) + np.trace(cov + cov2 - 2*cc)
    return np.real(dist)


@torch.no_grad()
def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
    print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    inception = InceptionV3().eval().to(device)
    loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]

    mu, cov = [], []
    for loader in loaders:
        actvs = []
        for x in tqdm(loader, total=len(loader)):
            actv = inception(x.to(device))
            actvs.append(actv)
        actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
        mu.append(np.mean(actvs, axis=0))
        cov.append(np.cov(actvs, rowvar=False))
    fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
    return fid_value


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--paths', type=str, nargs=2, help='paths to real and fake images')
    parser.add_argument('--img_size', type=int, default=256, help='image resolution')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size to use')
    args = parser.parse_args()
    fid_value = calculate_fid_given_paths(args.paths, args.img_size, args.batch_size)
    print('FID: ', fid_value)

# python fid.py --paths PATH_REAL PATH_FAKE
--------------------------------------------------------------------------------
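For a quick cross-check of the hand-rolled FID above, torchmetrics also provides a FrechetInceptionDistance metric (its Inception backend is an extra dependency). A minimal sketch with placeholder uint8 batches; note that FID is only meaningful with a reasonably large number of images per set:

import torch
from torchmetrics.image.fid import FrechetInceptionDistance  # may require: pip install torchmetrics[image]

fid = FrechetInceptionDistance(feature=2048)

# placeholder batches; in practice these would be real and generated images
real_images = torch.randint(0, 256, (50, 3, 299, 299), dtype=torch.uint8)
fake_images = torch.randint(0, 256, (50, 3, 299, 299), dtype=torch.uint8)

fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(fid.compute())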
["a photograph of an astronaut riding a horse"] 26 | height = 512 # default height of Stable Diffusion 27 | width = 512 # default width of Stable Diffusion 28 | num_inference_steps = 25 # Number of denoising steps 29 | guidance_scale = 7.5 # Scale for classifier-free guidance 30 | generator = torch.manual_seed(0) # Seed generator to create the inital latent noise 31 | batch_size = len(prompt) 32 | scheduler.set_timesteps(num_inference_steps) 33 | print(scheduler.timesteps) 34 | 35 | """ 36 | Step 1. 37 | Make text embeddings 38 | """ 39 | text_input = tokenizer( 40 | prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" 41 | ) 42 | 43 | with torch.no_grad(): 44 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0] 45 | 46 | uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt") 47 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 48 | 49 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 50 | 51 | """ 52 | Step 2. 53 | Reverse process 54 | """ 55 | # Create random noise 56 | latents = torch.randn( 57 | (batch_size, unet.config.in_channels, height // 8, width // 8), 58 | generator=generator, 59 | ) 60 | latents = latents.to(torch_device) 61 | latents = latents * scheduler.init_noise_sigma # PNDMS = 1 62 | 63 | for t in tqdm(scheduler.timesteps): 64 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 65 | latent_model_input = torch.cat([latents] * 2) 66 | 67 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t) 68 | 69 | # predict the noise residual 70 | with torch.no_grad(): 71 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 72 | 73 | # perform guidance 74 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 75 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 76 | 77 | # compute the previous noisy sample x_t -> x_t-1 78 | latents = scheduler.step(noise_pred, t, latents).prev_sample 79 | 80 | 81 | """ 82 | Step 3. 83 | Image decoding 84 | """ 85 | latents = 1 / 0.18215 * latents 86 | with torch.no_grad(): 87 | image = vae.decode(latents).sample 88 | 89 | image = (image / 2 + 0.5).clamp(0, 1) 90 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 91 | images = (image * 255).round().astype("uint8") 92 | pil_images = [Image.fromarray(image) for image in images] 93 | pil_images[0].save("main_results.png") -------------------------------------------------------------------------------- /src/stable_diffusion/sd_simple_main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import StableDiffusionPipeline 3 | 4 | # pip install diffusers 5 | 6 | model_ckpt = "CompVis/stable-diffusion-v1-4" 7 | device = "mps" # cuda, cpu, mps 8 | weight_dtype = torch.float16 9 | 10 | pipe = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=weight_dtype) 11 | pipe = pipe.to(device) 12 | 13 | prompt = "a photograph of an astronaut riding a horse" 14 | image = pipe(prompt).images[0] 15 | image.save("simple_results.png") --------------------------------------------------------------------------------