├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── .DS_Store
│   └── figs
│       ├── .DS_Store
│       ├── advanced_fig.png
│       ├── basic_fig.png
│       ├── gan_fig.png
│       ├── handson_fig.png
│       └── teaser.png
├── lecture
│   ├── 1_Basic_diffusion.pdf
│   ├── 2_Advanced_diffusion.pdf
│   └── 3_HandsOn_diffusion_noans.pdf
└── src
    ├── .DS_Store
    ├── GALIP
    │   ├── .DS_Store
    │   ├── GALIP.py
    │   ├── dataset
    │   │   ├── .DS_Store
    │   │   └── coco_2017
    │   │       ├── .DS_Store
    │   │       ├── train
    │   │       │   ├── .DS_Store
    │   │       │   ├── image
    │   │       │   │   ├── 000000000009.jpg
    │   │       │   │   ├── 000000000025.jpg
    │   │       │   │   ├── 000000000030.jpg
    │   │       │   │   ├── 000000000034.jpg
    │   │       │   │   ├── 000000000036.jpg
    │   │       │   │   └── 000000000042.jpg
    │   │       │   └── text
    │   │       │       ├── 000000000009.txt
    │   │       │       ├── 000000000025.txt
    │   │       │       ├── 000000000030.txt
    │   │       │       ├── 000000000034.txt
    │   │       │       ├── 000000000036.txt
    │   │       │       └── 000000000042.txt
    │   │       └── val
    │   │           ├── .DS_Store
    │   │           ├── image
    │   │           │   ├── 000000000139.jpg
    │   │           │   ├── 000000000285.jpg
    │   │           │   ├── 000000000632.jpg
    │   │           │   ├── 000000000724.jpg
    │   │           │   ├── 000000000776.jpg
    │   │           │   └── 000000000785.jpg
    │   │           └── text
    │   │               ├── 000000000139.txt
    │   │               ├── 000000000285.txt
    │   │               ├── 000000000632.txt
    │   │               ├── 000000000724.txt
    │   │               ├── 000000000776.txt
    │   │               └── 000000000785.txt
    │   ├── main.py
    │   ├── metric
    │   │   └── fid_score.py
    │   ├── networks.py
    │   ├── ops.py
    │   └── utils.py
    ├── ddpm_ddim
    │   ├── .DS_Store
    │   ├── dataset
    │   │   ├── .DS_Store
    │   │   └── cat
    │   │       ├── flickr_cat_000008.png
    │   │       ├── flickr_cat_000011.png
    │   │       ├── flickr_cat_000016.png
    │   │       ├── flickr_cat_000056.png
    │   │       └── flickr_cat_000076.png
    │   ├── main.py
    │   ├── main_template.py
    │   ├── modules.py
    │   ├── noise.jpg
    │   ├── noise_test.py
    │   └── utils.py
    ├── evaluation
    │   ├── clipscore.py
    │   ├── data_loader.py
    │   └── fid.py
    └── stable_diffusion
        ├── sd_main.py
        └── sd_simple_main.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Junho Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # diffusion-pytorch
2 | #### 이화여대 강의자료입니다. 사용시 citation 부탁드립니다. :)
3 | #### Teaching materials from Ewha Womans University. Please cite the link when used. :)
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## Youtube (Korean)
11 | * [The recipe of GANs](https://www.youtube.com/watch?v=vZdEGcLU_8U)
12 | * [The basic diffusion](https://www.youtube.com/watch?v=jaPPALsUZo8)
13 | * [The advanced diffusion](https://www.youtube.com/watch?v=Z8WWriIh1PU)
14 |
15 | ## Author
16 | [Junho Kim](http://bit.ly/jhkim_resume)
17 |
18 | ---
19 | ## Summary of GANs
20 |
21 |
22 |
23 |
24 | ---
25 |
26 | ## Basic diffusion (Theory)
27 | * DDPM, DDIM
28 | * Classifier guidance
29 | * Diffusion + GAN (DDGAN)
30 |
31 |
32 |
33 |
34 |
35 | ---
36 | ## Advanced diffusion (Theory)
37 | * Stable diffusion, GALIP
38 | * Evaluation
39 | * Editing
40 |
41 |
42 |
43 |
44 |
45 | ---
46 | ## Hands-on diffusion (Implementation)
47 | * DDPM, DDIM
48 | * How to use the SD? (see the usage sketch after this file)
49 | * How to evaluate?
50 |
51 |
52 |
53 |
54 |
55 | ---
56 | ### Recommended code
57 | * [pytorch & tensorflow code template](https://github.com/taki0112/tf-torch-template)
58 | * [Stylegan2-pytorch](https://github.com/taki0112/stylegan2-pytorch)
59 | * [GALIP-pytorch](https://github.com/taki0112/diffusion-pytorch/tree/main/src/GALIP)
60 | * [DDGAN-tensorflow](https://github.com/taki0112/denoising-diffusion-gan-Tensorflow)
61 |
--------------------------------------------------------------------------------
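The Hands-on list in the README above asks "How to use the SD?". The repo's own scripts (`src/stable_diffusion/sd_main.py`, `sd_simple_main.py`) are not reproduced in this excerpt, so the following is only a minimal sketch using the Hugging Face `diffusers` pipeline; the model id, prompt, and sampling parameters are assumptions, not the repo's settings.

```python
# Minimal text-to-image sketch with Hugging Face diffusers (illustrative only; not the repo's
# sd_main.py / sd_simple_main.py, and the checkpoint id is an assumption).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",   # assumed checkpoint; any SD 1.x weights behave the same way
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    "a giraffe standing next to a tree",  # caption in the style of the COCO texts below
    num_inference_steps=50,
    guidance_scale=7.5,                   # classifier-free guidance strength
).images[0]
image.save("sd_sample.png")
```

Lowering `num_inference_steps` trades quality for speed, and `guidance_scale` controls how strongly the sample follows the prompt.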
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/figs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/.DS_Store
--------------------------------------------------------------------------------
/assets/figs/advanced_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/advanced_fig.png
--------------------------------------------------------------------------------
/assets/figs/basic_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/basic_fig.png
--------------------------------------------------------------------------------
/assets/figs/gan_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/gan_fig.png
--------------------------------------------------------------------------------
/assets/figs/handson_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/handson_fig.png
--------------------------------------------------------------------------------
/assets/figs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/assets/figs/teaser.png
--------------------------------------------------------------------------------
/lecture/1_Basic_diffusion.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/lecture/1_Basic_diffusion.pdf
--------------------------------------------------------------------------------
/lecture/2_Advanced_diffusion.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/lecture/2_Advanced_diffusion.pdf
--------------------------------------------------------------------------------
/lecture/3_HandsOn_diffusion_noans.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/lecture/3_HandsOn_diffusion_noans.pdf
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/.DS_Store
--------------------------------------------------------------------------------
/src/GALIP/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/.DS_Store
--------------------------------------------------------------------------------
/src/GALIP/GALIP.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from ops import *
3 | from utils import *
4 | import time
5 | from torch.utils.tensorboard import SummaryWriter
6 | import numpy as np
7 | import torchvision
8 | from functools import partial
9 | import torch.nn.functional as F
10 |
11 | print = partial(print, flush=True)
12 |
13 | from metric.fid_score import InceptionV3, calculate_fid_t2i
14 | from networks import *
15 |
16 | def run_fn(rank, args, world_size):
17 | device = torch.device('cuda', rank)
18 | torch.backends.cudnn.benchmark = True
19 |
20 | model = GALIP(args, world_size)
21 | model.build_model(rank, device)
22 | model.train_model(rank, device)
23 |
24 | class GALIP():
25 | def __init__(self, args, NUM_GPUS):
26 | super(GALIP, self).__init__()
27 |
28 | """ Model """
29 | self.model_name = 'GALIP'
30 | self.phase = args['phase']
31 | self.NUM_GPUS = NUM_GPUS
32 |
33 |
34 | """ Training parameters """
35 | self.img_size = args['img_size']
36 | self.batch_size = args['batch_size']
37 | self.global_batch_size = self.batch_size * self.NUM_GPUS
38 | self.epoch = args['epoch']
39 | if self.epoch != 0:
40 | self.iteration = None
41 | else:
42 | self.iteration = args['iteration']
43 | self.mixed_flag = args['mixed_flag']
44 | self.growth_interval = 2000
45 | self.scaler_min = 64
46 |
47 | """ Network parameters """
48 | self.style_dim = 100
49 | self.g_lr = args['g_lr']
50 | self.d_lr = args['d_lr']
51 |
52 | """ Print parameters """
53 | self.print_freq = args['print_freq']
54 | self.save_freq = args['save_freq']
55 | self.log_template = 'step [{}/{}]: elapsed: {:.2f}s, BEST_FID: {:.2f}'
56 |
57 | """ Dataset Path """
58 | self.dataset_name = args['dataset']
59 | self.val_dataset_name = self.dataset_name + '_val'
60 | dataset_path = './dataset'
61 | self.dataset_path = os.path.join(dataset_path, self.dataset_name)
62 | self.val_dataset_path = os.path.join(dataset_path, self.val_dataset_name)
63 |
64 | """ Directory """
65 | self.checkpoint_dir = args['checkpoint_dir']
66 | self.result_dir = args['result_dir']
67 | self.log_dir = args['log_dir']
68 | self.sample_dir = args['sample_dir']
69 |
70 | self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
71 | check_folder(self.sample_dir)
72 | self.checkpoint_dir = os.path.join(self.checkpoint_dir, self.model_dir)
73 | check_folder(self.checkpoint_dir)
74 | self.log_dir = os.path.join(self.log_dir, self.model_dir)
75 | check_folder(self.log_dir)
76 |
77 | ##################################################################################
78 | # Model
79 | ##################################################################################
80 | def build_model(self, rank, device):
81 | """ Init process """
82 | build_init_procss(rank, world_size=self.NUM_GPUS, device=device)
83 |
84 | """ Dataset Load """
85 | dataset = ImageTextDataset(dataset_path=self.dataset_path, img_size=self.img_size)
86 | self.dataset_num = dataset.__len__()
87 | self.iteration = self.epoch * self.dataset_num // self.global_batch_size
88 | loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=4,
89 | sampler=distributed_sampler(dataset, rank=rank, num_replicas=self.NUM_GPUS, shuffle=True),
90 | drop_last=True, pin_memory=True)
91 | self.dataset_iter = infinite_iterator(loader)
92 |
93 | """ For FID """
94 | self.fid_dataset = ImageTextDataset(dataset_path=self.val_dataset_path, img_size=299, imagenet_normalization=True)
95 | self.fid_loader = torch.utils.data.DataLoader(self.fid_dataset, batch_size=5, num_workers=4,
96 |                                                       sampler=distributed_sampler(self.fid_dataset, rank=rank, num_replicas=self.NUM_GPUS, shuffle=False),  # draw indices from the validation set, not the training set
97 | drop_last=False, pin_memory=True)
98 | self.inception = InceptionV3(mixed_precision=self.mixed_flag).to(device)
99 | requires_grad(self.inception, False)
100 |
101 | """ Pretrain Model Load """
102 | self.clip = clip.load('ViT-B/32')[0].eval().to(device)
103 | self.clip_img = CLIP_IMG_ENCODER(self.clip).to(device)
104 | self.clip_text = CLIP_TXT_ENCODER(self.clip).to(device)
105 |
106 | requires_grad(self.clip_img, False)
107 | requires_grad(self.clip_text, False)
108 |
109 | self.clip_img.eval()
110 | self.clip_text.eval()
111 |
112 |
113 | """ Network """
114 | if self.mixed_flag:
115 | self.scaler_G = torch.cuda.amp.GradScaler(growth_interval=self.growth_interval)
116 | self.scaler_D = torch.cuda.amp.GradScaler(growth_interval=self.growth_interval)
117 | else:
118 | self.scaler_G = None
119 | self.scaler_D = None
120 | self.generator = NetG(imsize=self.img_size, CLIP=self.clip, nz=self.style_dim, mixed_precision=self.mixed_flag).to(device)
121 | self.discriminator = NetD(imsize=self.img_size, mixed_precision=self.mixed_flag).to(device)
122 | self.predictor = NetC(mixed_precision=self.mixed_flag).to(device)
123 |
124 |
125 | """ Optimizer """
126 | D_params = list(self.discriminator.parameters()) + list(self.predictor.parameters())
127 | self.g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.g_lr, betas=(0.0, 0.9), eps=1e-08)
128 | self.d_optimizer = torch.optim.Adam(D_params, lr=self.d_lr, betas=(0.0, 0.9), eps=1e-08)
129 |
130 | """ Distributed Learning """
131 | self.generator = dataparallel_and_sync(self.generator, rank)
132 | self.discriminator = dataparallel_and_sync(self.discriminator, rank)
133 | self.predictor = dataparallel_and_sync(self.predictor, rank)
134 |
135 | """ Checkpoint """
136 | self.ckpt_dict= {
137 | 'generator': self.generator.state_dict(),
138 | 'discriminator': self.discriminator.state_dict(),
139 | 'predictor' : self.predictor.state_dict(),
140 | 'g_optimizer': self.g_optimizer.state_dict(),
141 | 'd_optimizer':self.d_optimizer.state_dict()
142 |         }
143 |
144 | latest_ckpt_name, start_iter = find_latest_ckpt(self.checkpoint_dir)
145 | if latest_ckpt_name is not None:
146 | if rank == 0:
147 | print('Latest checkpoint restored!! ', latest_ckpt_name)
148 | print('start iteration : ', start_iter)
149 | self.start_iteration = start_iter
150 |
151 | latest_ckpt = os.path.join(self.checkpoint_dir, latest_ckpt_name)
152 | ckpt = torch.load(latest_ckpt, map_location=device)
153 |
154 | self.generator.load_state_dict(ckpt["generator"])
155 | self.discriminator.load_state_dict(ckpt['discriminator'])
156 | self.predictor.load_state_dict(ckpt['predictor'])
157 | self.g_optimizer.load_state_dict(ckpt["g_optimizer"])
158 | self.d_optimizer.load_state_dict(ckpt["d_optimizer"])
159 |
160 | else:
161 | if rank == 0:
162 | print('Not restoring from saved checkpoint')
163 | self.start_iteration = 0
164 |
165 | def g_train_step(self, real_img, tokens, device=torch.device('cuda')):
166 | self.generator.train()
167 | self.discriminator.train()
168 | self.predictor.train()
169 |
170 | # step 0: pre-process
171 | with torch.cuda.amp.autocast() if self.mixed_flag else dummy_context_mgr() as mpc:
172 | with torch.no_grad():
173 | sent_emb, word_emb = self.clip_text(tokens) # [bs, 512], [bs, 77, 512]
174 | word_emb = word_emb.detach()
175 | sent_emb = sent_emb.detach()
176 |
177 | # synthesize fake images
178 | noise = torch.randn([self.batch_size, self.style_dim]).to(device)
179 | fake_img = self.generator(noise, sent_emb)
180 | CLIP_fake, fake_emb = self.clip_img(fake_img)
181 |
182 | # loss
183 | fake_feats = self.discriminator(CLIP_fake)
184 | output = self.predictor(fake_feats, sent_emb)
185 | text_img_sim = torch.cosine_similarity(fake_emb, sent_emb).mean()
186 |             loss = -output.mean() - 4.0 * text_img_sim  # adversarial term plus weighted CLIP text-image similarity
187 |
188 | apply_gradients(loss, self.g_optimizer, self.mixed_flag, self.scaler_G, self.scaler_min)
189 |
190 | return loss, sent_emb
191 |
192 | def d_train_step(self, real_img, tokens, device=torch.device('cuda')):
193 | self.generator.train()
194 | self.discriminator.train()
195 | self.predictor.train()
196 |
197 | # step 0: pre-process
198 | with torch.cuda.amp.autocast() if self.mixed_flag else dummy_context_mgr() as mpc:
199 | with torch.no_grad():
200 | sent_emb, word_emb = self.clip_text(tokens) # [bs, 512], [bs, 77, 512]
201 | word_emb = word_emb.detach()
202 | sent_emb = sent_emb.detach()
203 |
204 |
205 | # loss
206 | real_img = real_img.requires_grad_()
207 | sent_emb = sent_emb.requires_grad_()
208 | word_emb = word_emb.requires_grad_()
209 |
210 | # predict real
211 | CLIP_real, real_emb = self.clip_img(real_img) # [bs, 3, 768, 7, 7], [bs, 512]
212 | real_feats = self.discriminator(CLIP_real) # [bs, 512, 7, 7]
213 | pred_real, errD_real = predict_loss(self.predictor, real_feats, sent_emb, negtive=False)
214 |
215 | # predict mismatch
216 |             mis_sent_emb = torch.cat((sent_emb[1:], sent_emb[0:1]), dim=0).detach()  # roll the batch by one to form mismatched text-image pairs
217 | _, errD_mis = predict_loss(self.predictor, real_feats, mis_sent_emb, negtive=True)
218 |
219 | # synthesize fake images
220 | noise = torch.randn([self.batch_size, self.style_dim]).to(device)
221 | fake_img = self.generator(noise, sent_emb)
222 | CLIP_fake, fake_emb = self.clip_img(fake_img)
223 | fake_feats = self.discriminator(CLIP_fake.detach())
224 | _, errD_fake = predict_loss(self.predictor, fake_feats, sent_emb, negtive=True)
225 |
226 | if self.mixed_flag:
227 | errD_MAGP = MA_GP_MP(CLIP_real, sent_emb, pred_real, self.scaler_D)
228 | else:
229 | errD_MAGP = MA_GP_FP32(CLIP_real, sent_emb, pred_real)
230 |
231 | with torch.cuda.amp.autocast() if self.mixed_flag else dummy_context_mgr() as mpc:
232 | loss = errD_real + (errD_fake + errD_mis) / 2.0 + errD_MAGP
233 |
234 | apply_gradients(loss, self.d_optimizer, self.mixed_flag, self.scaler_D, self.scaler_min)
235 |
236 | return loss
237 |
238 | def train_model(self, rank, device):
239 | start_time = time.time()
240 | fid_start_time = time.time()
241 |
242 | # setup tensorboards
243 | train_summary_writer = SummaryWriter(self.log_dir)
244 |
245 | # start training
246 | if rank == 0:
247 | print()
248 | print(self.dataset_path)
249 | print("Dataset number : ", self.dataset_num)
250 | print("GPUs : ", self.NUM_GPUS)
251 | print("Each batch size : ", self.batch_size)
252 | print("Global batch size : ", self.global_batch_size)
253 | print("Target image size : ", self.img_size)
254 | print("Print frequency : ", self.print_freq)
255 | print("Save frequency : ", self.save_freq)
256 | print("PyTorch Version :", torch.__version__)
257 | print('max_steps: {}'.format(self.iteration))
258 | print()
259 | losses = {'g_loss': 0.0, 'd_loss': 0.0}
260 |
261 | fid_dict = {'metric/fid': 0.0, 'metric/best_fid': 0.0, 'metric/best_fid_iter': 0}
262 | fid = 0
263 | best_fid = 1000
264 | best_fid_iter = 0
265 |
266 | for idx in range(self.start_iteration, self.iteration):
267 | iter_start_time = time.time()
268 |
269 | image, tokens, text = next(self.dataset_iter)
270 | image = image.to(device)
271 | tokens = tokens.to(device)
272 | # text = text.to(device)
273 |
274 | if idx == 0:
275 | if rank == 0:
276 | print("count params")
277 | g_params = count_parameters(self.generator)
278 | d_params = count_parameters(self.discriminator) + count_parameters(self.predictor)
279 | g_B, g_M = convert_to_billion_and_million(g_params)
280 | d_B, d_M = convert_to_billion_and_million(d_params)
281 |
282 | t_B = g_B + d_B
283 | t_M = g_M + d_M
284 |
285 | print("G network parameters : {}B, {}M".format(g_B, g_M))
286 | print("D network parameters : {}B, {}M".format(d_B, d_M))
287 | print("Total network parameters : {}B, {}M".format(t_B, t_M))
288 | print()
289 |
290 | loss = self.d_train_step(image, tokens, device=device)
291 |
292 | losses['d_loss'] = loss
293 |
294 | loss, text_embed = self.g_train_step(image, tokens, device)
295 | losses['g_loss'] = loss
296 |
297 | losses = reduce_loss_dict(losses)
298 | losses = dict_to_numpy(losses, python_value=True)
299 |
300 | if np.mod(idx, self.print_freq) == 0 or idx == self.iteration - 1 :
301 | if rank == 0:
302 | print("calculate fid ...")
303 | fid_start_time = time.time()
304 |
305 | fid = calculate_fid_t2i(self.fid_loader, self.generator, self.inception, self.clip_text, self.val_dataset_name,
306 | device=device, latent_dim=self.style_dim)
307 |
308 | if rank == 0:
309 | fid_end_time = time.time()
310 | fid_elapsed = fid_end_time - fid_start_time
311 | print("calculate fid finish: {:.2f}s".format(fid_elapsed))
312 | if fid < best_fid:
313 | print("BEST FID UPDATED")
314 | best_fid = fid
315 | best_fid_iter = idx
316 | self.torch_save(idx, fid)
317 |
318 | fid_dict['metric/best_fid'] = best_fid
319 | fid_dict['metric/best_fid_iter'] = best_fid_iter
320 | fid_dict['metric/fid'] = fid
321 |
322 |
323 | if rank == 0:
324 | # save to tensorboard
325 |
326 | for k, v in losses.items():
327 | train_summary_writer.add_scalar(k, v, global_step=idx)
328 |
329 | if np.mod(idx, self.print_freq) == 0 or idx == self.iteration - 1:
330 | train_summary_writer.add_scalar('fid', fid, global_step=idx)
331 |
332 | if np.mod(idx + 1, self.print_freq) == 0:
333 | with torch.no_grad():
334 | batch_size = text_embed.shape[0]
335 |
336 | noise = torch.randn([batch_size, self.style_dim]).to(device)
337 | self.generator.eval()
338 | fake_img = self.generator(noise, text_embed)
339 | fake_img = torch.clamp(fake_img, -1.0, 1.0)
340 |
341 | partial_size = int(batch_size ** 0.5)
342 |
343 | # resize
344 | fake_img = F.interpolate(fake_img, size=256, mode='bicubic', align_corners=True)
345 | torchvision.utils.save_image(fake_img, './{}/fake_{:06d}.png'.format(self.sample_dir, idx + 1),
346 | nrow=partial_size,
347 | normalize=True, range=(-1, 1))
348 | text_path = './{}/fake_{:06d}.txt'.format(self.sample_dir, idx+1)
349 | with open(text_path, 'w') as f:
350 | f.write('\n'.join(text))
351 | # normalize = set to the range (0, 1) by range(min, max)
352 |
353 | elapsed = time.time() - iter_start_time
354 | print(self.log_template.format(idx, self.iteration, elapsed, best_fid))
355 |
356 | dist.barrier()
357 |
358 | if rank == 0:
359 | # save model for final step
360 | self.torch_save(self.iteration, fid)
361 |
362 | print("LAST FID: ", fid)
363 | print("BEST FID: {}, {}".format(best_fid, best_fid_iter))
364 | print("Total train time: %4.4f" % (time.time() - start_time))
365 |
366 | dist.barrier()
367 |
368 | def torch_save(self, idx, fid=0):
369 | fid_int = int(fid)
370 | torch.save(
371 | self.ckpt_dict,
372 | os.path.join(self.checkpoint_dir, 'iter_{}_fid_{}.pt'.format(idx, fid_int))
373 | )
374 |
375 | @property
376 | def model_dir(self):
377 | return "{}_{}_{}_bs{}_{}GPUs_Mixed{}".format(self.model_name, self.dataset_name, self.img_size, self.batch_size, self.NUM_GPUS, self.mixed_flag)
--------------------------------------------------------------------------------
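For reference, the trainer above conditions `NetG` on a frozen CLIP sentence embedding plus a 100-d noise vector (see `g_train_step`). Below is a minimal inference sketch that reuses the same constructor arguments as `build_model`; checkpoint loading and the test phase are omitted, and `NetG`'s internals live in `networks.py` beyond this excerpt, so treat it as an illustration rather than the repo's test path.

```python
# Inference sketch for the trainer above: embed one caption with the frozen CLIP text encoder and
# sample one image from NetG. Checkpoint loading is omitted; shapes follow the comments in GALIP.py.
import torch
import clip
from networks import CLIP_TXT_ENCODER, NetG

device = torch.device('cuda')
clip_model = clip.load('ViT-B/32')[0].eval().to(device)      # same frozen CLIP as build_model
clip_text = CLIP_TXT_ENCODER(clip_model).to(device).eval()
generator = NetG(imsize=256, CLIP=clip_model, nz=100, mixed_precision=False).to(device).eval()

tokens = clip.tokenize(["a zebra grazing on lush green grass in a field"]).to(device)
with torch.no_grad():
    sent_emb, word_emb = clip_text(tokens)        # [1, 512], [1, 77, 512]
    noise = torch.randn([1, 100], device=device)  # style_dim = 100 in GALIP.__init__
    fake_img = generator(noise, sent_emb)         # values roughly in [-1, 1], as clamped during training
```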
/src/GALIP/dataset/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/.DS_Store
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/.DS_Store
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/.DS_Store
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/image/000000000009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000009.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/image/000000000025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000025.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/image/000000000030.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000030.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/image/000000000034.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000034.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/image/000000000036.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000036.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/image/000000000042.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/train/image/000000000042.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/text/000000000009.txt:
--------------------------------------------------------------------------------
1 | Closeup of bins of food that include broccoli and bread.
2 | A meal is presented in brightly colored plastic trays.
3 | there are containers filled with different kinds of foods
4 | Colorful dishes holding meat, vegetables, fruit, and bread.
5 | A bunch of trays that have different food.
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/text/000000000025.txt:
--------------------------------------------------------------------------------
1 | A giraffe eating food from the top of the tree.
2 | A giraffe standing up nearby a tree
3 | A giraffe mother with its baby in the forest.
4 | Two giraffes standing in a tree filled area.
5 | A giraffe standing next to a forest filled with trees.
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/text/000000000030.txt:
--------------------------------------------------------------------------------
1 | A flower vase is sitting on a porch stand.
2 | White vase with different colored flowers sitting inside of it.
3 | a white vase with many flowers on a stage
4 | A white vase filled with different colored flowers.
5 | A vase with red and white flowers outside on a sunny day.
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/text/000000000034.txt:
--------------------------------------------------------------------------------
1 | A zebra grazing on lush green grass in a field.
2 | Zebra reaching its head down to ground where grass is.
3 | The zebra is eating grass in the sun.
4 | A lone zebra grazing in some green grass.
5 | a Zebra grazing on grass in a green open field.
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/text/000000000036.txt:
--------------------------------------------------------------------------------
1 | Woman in swim suit holding parasol on sunny day.
2 | A woman posing for the camera, holding a pink, open umbrella and wearing a bright, floral, ruched bathing suit, by a life guard stand with lake, green trees, and a blue sky with a few clouds behind.
3 | A woman in a floral swimsuit holds a pink umbrella.
4 | A woman with an umbrella near the sea
5 | A girl in a bathing suit with a pink umbrella.
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/train/text/000000000042.txt:
--------------------------------------------------------------------------------
1 | This wire metal rack holds several pairs of shoes and sandals
2 | A dog sleeping on a show rack in the shoes.
3 | Various slides and other footwear rest in a metal basket outdoors.
4 | A small dog is curled up on top of the shoes
5 | a shoe rack with some shoes and a dog sleeping on them
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/.DS_Store
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/image/000000000139.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000139.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/image/000000000285.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000285.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/image/000000000632.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000632.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/image/000000000724.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000724.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/image/000000000776.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000776.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/image/000000000785.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/GALIP/dataset/coco_2017/val/image/000000000785.jpg
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/text/000000000139.txt:
--------------------------------------------------------------------------------
1 | A woman stands in the dining area at the table.
2 | A room with chairs, a table, and a woman in it.
3 | A woman standing in a kitchen by a window
4 | A person standing at a table in a room.
5 | A living area with a television and a table
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/text/000000000285.txt:
--------------------------------------------------------------------------------
1 | A big burly grizzly bear is show with grass in the background.
2 | The large brown bear has a black nose.
3 | Closeup of a brown bear sitting in a grassy area.
4 | A large bear that is sitting on grass.
5 | A close up picture of a brown bear's face.
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/text/000000000632.txt:
--------------------------------------------------------------------------------
1 | Bedroom scene with a bookcase, blue comforter and window.
2 | A bedroom with a bookshelf full of books.
3 | This room has a bed with blue sheets and a large bookcase
4 | A bed and a mirror in a small room.
5 | a bed room with a neatly made bed a window and a book shelf
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/text/000000000724.txt:
--------------------------------------------------------------------------------
1 | A stop sign is mounted upside-down on it's post.
2 | A stop sign that is hanging upside down.
3 | An upside down stop sign by the road.
4 | a stop sign put upside down on a metal pole
5 | A stop sign installed upside down on a street corner
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/text/000000000776.txt:
--------------------------------------------------------------------------------
1 | Three teddy bears, each a different color, snuggling together.
2 | Three stuffed animals are sitting on a bed.
3 | three teddy bears giving each other a hug
4 | A group of three stuffed animal teddy bears.
5 | Three stuffed bears hugging and sitting on a blue pillow
6 |
--------------------------------------------------------------------------------
/src/GALIP/dataset/coco_2017/val/text/000000000785.txt:
--------------------------------------------------------------------------------
1 | A woman posing for the camera standing on skis.
2 | a woman standing on skiis while posing for the camera
3 | A woman in a red jacket skiing down a slope
4 | A young woman is skiing down the mountain slope.
5 | a person on skis makes her way through the snow
6 |
--------------------------------------------------------------------------------
/src/GALIP/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import *
3 | from GALIP import run_fn
4 |
5 | """
6 | count params
7 | G network parameters : 51,830,211
8 | D network parameters : 30,806,021
9 | Total network parameters : 82,636,232
10 | """
11 |
12 | def parse_args():
13 | desc = "Pytorch implementation of GALIP"
14 | parser = argparse.ArgumentParser(description=desc)
15 | parser.add_argument('--phase', type=str, default='train', help='[train, test]')
16 | parser.add_argument('--dataset', type=str, default='coco_2017', help='dataset_name')
17 | # celeba_hq_text
18 | # coco_2017
19 | parser.add_argument('--epoch', type=int, default=3000, help='The total epoch')
20 | parser.add_argument('--iteration', type=int, default=1000000, help='The total iterations')
21 | parser.add_argument('--img_size', type=int, default=256, help='The size of image')
22 | parser.add_argument('--batch_size', type=int, default=64, help='batch sizes for each gpus')
23 | parser.add_argument('--mixed_flag', type=str2bool, default=True, help='Mixed Precision Flag')
24 | # single = 16
25 |
26 |     # StyleGAN parameter
27 | parser.add_argument("--g_lr", type=float, default=0.0001, help="g learning rate")
28 | parser.add_argument("--d_lr", type=float, default=0.0004, help="d learning rate")
29 |
30 | parser.add_argument('--print_freq', type=int, default=5000, help='The number of image_print_freq')
31 | parser.add_argument('--save_freq', type=int, default=50000, help='The number of ckpt_save_freq')
32 |
33 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
34 | help='Directory name to save the checkpoints')
35 | parser.add_argument('--result_dir', type=str, default='results',
36 | help='Directory name to save the generated images')
37 | parser.add_argument('--log_dir', type=str, default='logs',
38 | help='Directory name to save training logs')
39 | parser.add_argument('--sample_dir', type=str, default='samples',
40 | help='Directory name to save the samples_prev on training')
41 |
42 | return check_args(parser.parse_args())
43 |
44 |
45 | """checking arguments"""
46 | def check_args(args):
47 | # --checkpoint_dir
48 | check_folder(args.checkpoint_dir)
49 |
50 | # --result_dir
51 | check_folder(args.result_dir)
52 |
53 | # --result_dir
54 | check_folder(args.log_dir)
55 |
56 | # --sample_dir
57 | check_folder(args.sample_dir)
58 |
59 | # --batch_size
60 | try:
61 | assert args.batch_size >= 1
62 | except:
63 | print('batch size must be larger than or equal to one', flush=True)
64 |
65 | return args
66 |
67 | """main"""
68 | def main():
69 |
70 | args = vars(parse_args())
71 |
72 | # run
73 | multi_gpu_run(ddp_fn=run_fn, args=args)
74 |
75 | if __name__ == '__main__':
76 | main()
--------------------------------------------------------------------------------
/src/GALIP/metric/fid_score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from torchvision import models
6 | import torch.distributed as dist
7 | import math
8 | from tqdm import tqdm
9 | from torchvision import transforms
10 | from scipy import linalg
11 | import pickle, os
12 | from torch.nn.functional import adaptive_avg_pool2d
13 |
14 | class dummy_context_mgr():
15 | def __enter__(self):
16 | return None
17 |
18 | def __exit__(self, exc_type, exc_value, traceback):
19 | return False
20 |
21 | class GatherLayer(torch.autograd.Function):
22 | @staticmethod
23 | def forward(ctx, input):
24 | ctx.save_for_backward(input)
25 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
26 | dist.all_gather(output, input)
27 | return tuple(output)
28 |
29 | @staticmethod
30 | def backward(ctx, *grads):
31 | input, = ctx.saved_tensors
32 | grad_out = torch.zeros_like(input)
33 | grad_out[:] = grads[dist.get_rank()]
34 | return grad_out
35 |
36 |
37 | class InceptionV3_(nn.Module):
38 | def __init__(self):
39 | super().__init__()
40 | inception = models.inception_v3(weights='DEFAULT')
41 | # pretrained=True -> weights=Inception_V3_Weights.IMAGENET1K_V1
42 | # weights='DEFAULT' or weights='IMAGENET1K_V1'
43 | self.block1 = nn.Sequential(
44 | inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
45 | inception.Conv2d_2b_3x3,
46 | nn.MaxPool2d(kernel_size=3, stride=2))
47 | self.block2 = nn.Sequential(
48 | inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
49 | nn.MaxPool2d(kernel_size=3, stride=2))
50 | self.block3 = nn.Sequential(
51 | inception.Mixed_5b, inception.Mixed_5c,
52 | inception.Mixed_5d, inception.Mixed_6a,
53 | inception.Mixed_6b, inception.Mixed_6c,
54 | inception.Mixed_6d, inception.Mixed_6e)
55 | self.block4 = nn.Sequential(
56 | inception.Mixed_7a, inception.Mixed_7b,
57 | inception.Mixed_7c,
58 | nn.AdaptiveAvgPool2d(output_size=(1, 1)))
59 |
60 | def forward(self, x):
61 | x = self.block1(x)
62 | x = self.block2(x)
63 | x = self.block3(x)
64 | x = self.block4(x)
65 | return x.view(x.size(0), -1)
66 |
67 | class InceptionV3(nn.Module):
68 | """Pretrained InceptionV3 network returning feature maps"""
69 |
70 | # Index of default block of inception to return,
71 | # corresponds to output of final average pooling
72 | DEFAULT_BLOCK_INDEX = 3
73 |
74 | # Maps feature dimensionality to their output blocks indices
75 | BLOCK_INDEX_BY_DIM = {
76 | 64: 0, # First max pooling features
77 |         192: 1,   # Second max pooling features
78 | 768: 2, # Pre-aux classifier features
79 | 2048: 3 # Final average pooling features
80 | }
81 |
82 | def __init__(self,
83 | mixed_precision=False,
84 | output_blocks=[DEFAULT_BLOCK_INDEX],
85 | resize_input=True,
86 | normalize_input=True,
87 | requires_grad=False):
88 | """Build pretrained InceptionV3
89 |
90 | Parameters
91 | ----------
92 | output_blocks : list of int
93 | Indices of blocks to return features of. Possible values are:
94 | - 0: corresponds to output of first max pooling
95 | - 1: corresponds to output of second max pooling
96 | - 2: corresponds to output which is fed to aux classifier
97 | - 3: corresponds to output of final average pooling
98 | resize_input : bool
99 | If true, bilinearly resizes input to width and height 299 before
100 | feeding input to model. As the network without fully connected
101 | layers is fully convolutional, it should be able to handle inputs
102 | of arbitrary size, so resizing might not be strictly needed
103 | normalize_input : bool
104 | If true, normalizes the input to the statistics the pretrained
105 | Inception network expects
106 | requires_grad : bool
107 | If true, parameters of the model require gradient. Possibly useful
108 | for finetuning the network
109 | """
110 | super(InceptionV3, self).__init__()
111 |
112 | self.resize_input = resize_input
113 | self.normalize_input = normalize_input
114 | self.output_blocks = sorted(output_blocks)
115 | self.last_needed_block = max(output_blocks)
116 |
117 | assert self.last_needed_block <= 3, \
118 | 'Last possible output block index is 3'
119 |
120 | self.blocks = nn.ModuleList()
121 |
122 |         inception = models.inception_v3(weights='DEFAULT')  # pretrained=True is deprecated in recent torchvision
123 |
124 | # Block 0: input to maxpool1
125 | block0 = [
126 | inception.Conv2d_1a_3x3,
127 | inception.Conv2d_2a_3x3,
128 | inception.Conv2d_2b_3x3,
129 | nn.MaxPool2d(kernel_size=3, stride=2)
130 | ]
131 | self.blocks.append(nn.Sequential(*block0))
132 |
133 | # Block 1: maxpool1 to maxpool2
134 | if self.last_needed_block >= 1:
135 | block1 = [
136 | inception.Conv2d_3b_1x1,
137 | inception.Conv2d_4a_3x3,
138 | nn.MaxPool2d(kernel_size=3, stride=2)
139 | ]
140 | self.blocks.append(nn.Sequential(*block1))
141 |
142 | # Block 2: maxpool2 to aux classifier
143 | if self.last_needed_block >= 2:
144 | block2 = [
145 | inception.Mixed_5b,
146 | inception.Mixed_5c,
147 | inception.Mixed_5d,
148 | inception.Mixed_6a,
149 | inception.Mixed_6b,
150 | inception.Mixed_6c,
151 | inception.Mixed_6d,
152 | inception.Mixed_6e,
153 | ]
154 | self.blocks.append(nn.Sequential(*block2))
155 |
156 | # Block 3: aux classifier to final avgpool
157 | if self.last_needed_block >= 3:
158 | block3 = [
159 | inception.Mixed_7a,
160 | inception.Mixed_7b,
161 | inception.Mixed_7c,
162 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
163 | ]
164 | self.blocks.append(nn.Sequential(*block3))
165 |
166 | for param in self.parameters():
167 | param.requires_grad = requires_grad
168 |
169 | def forward(self, inp):
170 | """Get Inception feature maps
171 |
172 | Parameters
173 | ----------
174 | inp : torch.autograd.Variable
175 | Input tensor of shape Bx3xHxW. Values are expected to be in
176 | range (0, 1)
177 |
178 | Returns
179 | -------
180 | List of torch.autograd.Variable, corresponding to the selected output
181 | block, sorted ascending by index
182 | """
183 | outp = []
184 | x = inp
185 |
186 | if self.resize_input:
187 |             x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)  # F.upsample is deprecated in favor of F.interpolate
188 |
189 | if self.normalize_input:
190 | x = x.clone()
191 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
192 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
193 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
194 |
195 | for idx, block in enumerate(self.blocks):
196 | x = block(x)
197 | if idx in self.output_blocks:
198 | outp.append(x)
199 |
200 | if idx == self.last_needed_block:
201 | break
202 |
203 | return outp
204 |
205 | def extract_real_feature(data_loader, inception, device, t2i=False):
206 | feats = []
207 |
208 | if t2i:
209 | for img, tokens, txt in tqdm(data_loader):
210 | img = img.to(device)
211 | feat = inception(img)
212 |
213 | feats.append(feat)
214 | else:
215 | for img in tqdm(data_loader):
216 | img = img.to(device)
217 | feat = inception(img)
218 |
219 | feats.append(feat)
220 |
221 | feats = gather_feats(feats)
222 |
223 | return feats
224 |
225 | def normalize_fake_img(imgs):
226 | """
227 | mean = [0.485, 0.456, 0.406]
228 | std = [0.229, 0.224, 0.225]
229 |
230 | imgs = (imgs + 1) / 2 # -1 ~ 1 to 0~1
231 | imgs = torch.clamp(imgs, 0, 1)
232 | imgs = transforms.Resize(size=[299, 299], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)(imgs),
233 | imgs = transforms.Normalize(mean=mean, std=std)(imgs)
234 | """
235 |
236 | norm = transforms.Compose([
237 | transforms.Normalize((-1, -1, -1), (2, 2, 2)), # (x - (-1)) / 2 = (x + 1) / 2
238 | transforms.Resize((299, 299)),
239 | ])
240 |
241 | x = norm(imgs)
242 |
243 | return x
244 |
245 | def gather_feats(feats):
246 | feats = torch.cat(feats, dim=0)
247 | feats = torch.cat(GatherLayer.apply(feats), dim=0)
248 | feats = feats.detach().cpu().numpy()
249 |
250 | return feats
251 |
252 | def extract_fake_feature(generator, inception, num_gpus, device, latent_dim, fake_samples=50000, batch_size=16):
253 | num_batches = int(math.ceil(float(fake_samples) / float(batch_size * num_gpus)))
254 | feats = []
255 | for _ in tqdm(range(num_batches)):
256 | z = [torch.randn([batch_size, latent_dim], device=device)]
257 | fake_img = generator(z)
258 |
259 | fake_img = normalize_fake_img(fake_img)
260 |
261 | feat = inception(fake_img)
262 |
263 | feats.append(feat)
264 |
265 | feats = gather_feats(feats)
266 |
267 | return feats
268 |
269 | def extract_fake_feature_t2i(data_loader, generator, inception, clip_text, device, latent_dim=100, mixed_flag=False):
270 | # with torch.cuda.amp.autocast() if mixed_flag else dummy_context_mgr() as mpc:
271 | with torch.no_grad():
272 | feats = []
273 | try:
274 | for img, tokens, txt in tqdm(data_loader):
275 | # pre-process
276 | tokens = tokens.to(device)
277 | sent_emb, word_emb = clip_text(tokens) # [bs, 512], [bs, 77, 512]
278 | sent_emb = sent_emb.detach()
279 |
280 | # make fake_img
281 | noise = torch.randn([sent_emb.shape[0], latent_dim]).to(device)
282 | fake_img = generator(noise, sent_emb)
283 | fake_img = fake_img.float()
284 | fake_img = torch.clamp(fake_img, -1., 1.)
285 | fake_img = torch.nan_to_num(fake_img, nan=-1.0, posinf=1.0, neginf=-1.0)
286 |
287 | # get features of inception
288 | fake_img = normalize_fake_img(fake_img)
289 | feat = inception(fake_img)
290 |
291 | # galip
292 | pred = feat[0]
293 | if pred.shape[2] != 1 or pred.shape[3] != 1:
294 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
295 | pred = pred.squeeze(-1).squeeze(-1)
296 | feats.append(pred)
297 |
298 | except IndexError:
299 | pass
300 |
301 | feats = gather_feats(feats)
302 |
303 | return feats
304 |
305 | def get_statistics(feats):
306 | mu = np.mean(feats, axis=0)
307 | cov = np.cov(feats, rowvar=False)
308 |
309 | return mu, cov
310 |
311 | def frechet_distance(mu, cov, mu2, cov2):
312 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
313 |     dist = np.sum((mu - mu2)**2) + np.trace(cov + cov2 - 2*cc)  # FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2(C1 C2)^(1/2))
314 | return np.real(dist)
315 |
316 | @torch.no_grad()
317 | def calculate_fid(data_loader, generator_model, inception_model, dataset_name, rank, device,
318 | latent_dim, fake_samples=50000, batch_size=16):
319 |
320 | num_gpus = torch.cuda.device_count()
321 |
322 | generator_model = generator_model.eval()
323 | inception_model = inception_model.eval()
324 |
325 | pickle_name = '{}_mu_cov.pickle'.format(dataset_name)
326 | cache = os.path.exists(pickle_name)
327 |
328 | if cache:
329 | with open(pickle_name, 'rb') as f:
330 | real_mu, real_cov = pickle.load(f)
331 | else:
332 | real_feats = extract_real_feature(data_loader, inception_model, device=device)
333 | real_mu, real_cov = get_statistics(real_feats)
334 |
335 | if rank == 0:
336 | with open(pickle_name, 'wb') as f:
337 | pickle.dump((real_mu, real_cov), f, protocol=pickle.HIGHEST_PROTOCOL)
338 |
339 |
340 | fake_feats = extract_fake_feature(generator_model, inception_model, num_gpus, device, latent_dim, fake_samples, batch_size)
341 | fake_mu, fake_cov = get_statistics(fake_feats)
342 |
343 | fid = frechet_distance(real_mu, real_cov, fake_mu, fake_cov)
344 |
345 | return fid
346 |
347 | @torch.no_grad()
348 | def calculate_fid_t2i(data_loader, generator, inception, clip_text, dataset_name, device,
349 | latent_dim=100, mixed_flag=False):
350 | # coco: 5000
351 |
352 | generator = generator.eval()
353 | inception = inception.eval()
354 | clip_text = clip_text.eval()
355 |
356 | stats_path = '{}_fid_stats.npz'.format(dataset_name)
357 | x = np.load(stats_path)
358 | real_mu, real_cov = x['mu'], x['sigma']
359 |
360 |
361 | fake_feats = extract_fake_feature_t2i(data_loader, generator, inception, clip_text, device, latent_dim, mixed_flag)
362 | fake_mu, fake_cov = get_statistics(fake_feats)
363 |
364 | fid = frechet_distance(real_mu, real_cov, fake_mu, fake_cov)
365 |
366 | return fid
--------------------------------------------------------------------------------
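The FID computation above reduces to `get_statistics` plus `frechet_distance` on pooled Inception features. A small self-contained sanity check on synthetic Gaussian features follows (run from `src/GALIP` so the `metric` package resolves; the 64-d features are purely illustrative stand-ins for the 2048-d Inception activations).

```python
# Sanity-check sketch for the FID helpers above: two synthetic feature sets stand in for real and
# generated Inception activations. Identical statistics drive the distance toward zero.
import numpy as np
from metric.fid_score import get_statistics, frechet_distance

rng = np.random.default_rng(0)
real_feats = rng.normal(0.0, 1.0, size=(1000, 64))   # stand-in for pooled Inception features
fake_feats = rng.normal(0.1, 1.1, size=(1000, 64))   # slightly shifted "generated" distribution

real_mu, real_cov = get_statistics(real_feats)
fake_mu, fake_cov = get_statistics(fake_feats)
print("FID:", frechet_distance(real_mu, real_cov, fake_mu, fake_cov))  # ~0 for matching statistics
```

Shifting the fake distribution's mean or variance increases the distance, which is the behavior `calculate_fid_t2i` relies on when comparing generated samples against the cached real statistics.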
/src/GALIP/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from collections import OrderedDict
6 | import clip
7 | # clip : CLIP4evl
8 |
9 | """
10 | Compose(
11 | Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
12 | CenterCrop(size=(224, 224))
13 |
14 | ToTensor()
15 | Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
16 | )
17 | """
18 |
19 | class dummy_context_mgr():
20 | def __enter__(self):
21 | return None
22 |
23 | def __exit__(self, exc_type, exc_value, traceback):
24 | return False
25 |
26 | class CLIP_IMG_ENCODER(nn.Module):
27 | def __init__(self, CLIP):
28 | super(CLIP_IMG_ENCODER, self).__init__()
29 | model = CLIP.visual
30 | # print(model)
31 | self.define_module(model)
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 | def define_module(self, model):
36 | self.conv1 = model.conv1
37 | self.class_embedding = model.class_embedding
38 | self.positional_embedding = model.positional_embedding
39 | self.ln_pre = model.ln_pre
40 | self.transformer = model.transformer
41 | self.ln_post = model.ln_post
42 | self.proj = model.proj
43 |
44 | @property
45 | def dtype(self):
46 | return self.conv1.weight.dtype
47 |
48 | def transf_to_CLIP_input(self,inputs):
49 | device = inputs.device
50 | if len(inputs.size()) != 4:
51 | raise ValueError('Expect the (B, C, X, Y) tensor.')
52 | else:
53 | mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])\
54 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
55 | var = torch.tensor([0.26862954, 0.26130258, 0.27577711])\
56 | .unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
57 | inputs = F.interpolate(inputs*0.5+0.5, size=(224, 224))
58 | # inputs = ((inputs+1)*0.5-mean)/var
59 | inputs = (inputs - mean) / var
60 | return inputs
61 |
62 | def forward(self, img: torch.Tensor):
63 | x = self.transf_to_CLIP_input(img)
64 | x = x.type(self.dtype)
65 | x = self.conv1(x) # shape = [*, width, grid, grid]
66 | grid = x.size(-1)
67 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
68 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
69 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
70 | x = x + self.positional_embedding.to(x.dtype)
71 | x = self.ln_pre(x)
72 | # NLD -> LND
73 | x = x.permute(1, 0, 2)
74 | # Local features
75 | #selected = [1,4,7,12]
76 | selected = [1,4,8]
77 | local_features = []
78 | for i in range(12):
79 | x = self.transformer.resblocks[i](x)
80 | if i in selected:
81 | local_features.append(x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype))
82 | x = x.permute(1, 0, 2) # LND -> NLD
83 | x = self.ln_post(x[:, 0, :])
84 | if self.proj is not None:
85 | x = x @ self.proj
86 | return torch.stack(local_features, dim=1), x.type(img.dtype)
87 |
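# A minimal usage sketch (assumes the ViT-B/32 CLIP backbone this module is written
# for: 12 resblocks, width 768, a 7x7 patch grid for 224x224 inputs, and a 512-d
# projected embedding; illustrative only):
def _clip_img_encoder_sketch():
    CLIP_model, _ = clip.load("ViT-B/32", device="cpu")
    img_encoder = CLIP_IMG_ENCODER(CLIP_model).eval()
    imgs = torch.randn(2, 3, 224, 224)           # stand-in for images normalized to [-1, 1]
    local_feats, global_emb = img_encoder(imgs)
    # local_feats: [2, 3, 768, 7, 7] -- one feature map per index in `selected`
    # global_emb : [2, 512]          -- projected CLS-token embedding
    return local_feats, global_emb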
88 |
89 | class CLIP_TXT_ENCODER(nn.Module):
90 | def __init__(self, CLIP):
91 | super(CLIP_TXT_ENCODER, self).__init__()
92 | self.define_module(CLIP)
93 | # print(model)
94 | for param in self.parameters():
95 | param.requires_grad = False
96 |
97 | def define_module(self, CLIP):
98 | self.transformer = CLIP.transformer
99 | self.vocab_size = CLIP.vocab_size
100 | self.token_embedding = CLIP.token_embedding
101 | self.positional_embedding = CLIP.positional_embedding
102 | self.ln_final = CLIP.ln_final
103 | self.text_projection = CLIP.text_projection
104 |
105 | @property
106 | def dtype(self):
107 | return self.transformer.resblocks[0].mlp.c_fc.weight.dtype
108 |
109 | def forward(self, text):
110 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
111 | x = x + self.positional_embedding.type(self.dtype)
112 | x = x.permute(1, 0, 2) # NLD -> LND
113 | x = self.transformer(x)
114 | x = x.permute(1, 0, 2) # LND -> NLD
115 | x = self.ln_final(x).type(self.dtype)
116 | # x.shape = [batch_size, n_ctx, transformer.width]
117 | # take features from the eot embedding (eot_token is the highest number in each sequence)
118 | sent_emb = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
119 | return sent_emb, x
120 |
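# Note on sent_emb above: clip.tokenize() pads sequences with zeros and the
# end-of-text token has the largest id in CLIP's vocabulary (49407), so
# text.argmax(dim=-1) returns the position of the EOT token, whose hidden state is
# projected to obtain the sentence embedding. A tiny illustration (word ids are
# hypothetical):
#   tokens = torch.tensor([[49406, 320, 1125, 539, 320, 2368, 49407, 0, 0]])
#   tokens.argmax(dim=-1)  # -> tensor([6]), the EOT position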
121 |
122 | class CLIP_Mapper(nn.Module):
123 | def __init__(self, CLIP):
124 | super(CLIP_Mapper, self).__init__()
125 | model = CLIP.visual
126 | # print(model)
127 | self.define_module(model)
128 | for param in model.parameters():
129 | param.requires_grad = False
130 |
131 | def define_module(self, model):
132 | self.conv1 = model.conv1
133 | self.class_embedding = model.class_embedding
134 | self.positional_embedding = model.positional_embedding
135 | self.ln_pre = model.ln_pre
136 | self.transformer = model.transformer
137 |
138 | @property
139 | def dtype(self):
140 | return self.conv1.weight.dtype
141 |
142 | def forward(self, img: torch.Tensor, prompts: torch.Tensor):
143 | x = img.type(self.dtype)
144 | prompts = prompts.type(self.dtype)
145 | grid = x.size(-1)
146 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
147 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
148 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
149 | # shape = [*, grid ** 2 + 1, width]
150 | x = x + self.positional_embedding.to(x.dtype)
151 | x = self.ln_pre(x)
152 | # NLD -> LND
153 | x = x.permute(1, 0, 2)
154 | # Local features
155 | selected = [1,2,3,4,5,6,7,8]
156 | begin, end = 0, 12
157 | prompt_idx = 0
158 | for i in range(begin, end):
159 | if i in selected:
160 | prompt = prompts[:,prompt_idx,:].unsqueeze(0)
161 | prompt_idx = prompt_idx+1
162 | x = torch.cat((x,prompt), dim=0)
163 | x = self.transformer.resblocks[i](x)
164 | x = x[:-1,:,:]
165 | else:
166 | x = self.transformer.resblocks[i](x)
167 | return x.permute(1, 0, 2)[:, 1:, :].permute(0, 2, 1).reshape(-1, 768, grid, grid).contiguous().type(img.dtype)
168 |
169 |
170 | class CLIP_Adapter(nn.Module):
171 | def __init__(self, in_ch, mid_ch, out_ch, G_ch, CLIP_ch, cond_dim, k, s, p, map_num, CLIP):
172 | super(CLIP_Adapter, self).__init__()
173 | self.CLIP_ch = CLIP_ch
174 | self.FBlocks = nn.ModuleList([])
175 | self.FBlocks.append(M_Block(in_ch, mid_ch, out_ch, cond_dim, k, s, p))
176 | for i in range(map_num-1):
177 | self.FBlocks.append(M_Block(out_ch, mid_ch, out_ch, cond_dim, k, s, p))
178 | self.conv_fuse = nn.Conv2d(out_ch, CLIP_ch, 5, 1, 2)
179 | self.CLIP_ViT = CLIP_Mapper(CLIP)
180 | self.conv = nn.Conv2d(768, G_ch, 5, 1, 2)
181 | #
182 | self.fc_prompt = nn.Linear(cond_dim, CLIP_ch*8)
183 |
184 | def forward(self,out,c):
185 | prompts = self.fc_prompt(c).view(c.size(0),-1,self.CLIP_ch)
186 | # [1, 8, 768]
187 | for FBlock in self.FBlocks:
188 | out = FBlock(out,c)
189 | # out -> [1, 64, 7, 7]
190 | fuse_feat = self.conv_fuse(out)
191 | # fuse_feat -> [1, 768, 7, 7]
192 | map_feat = self.CLIP_ViT(fuse_feat,prompts)
193 | # map_feat -> [1, 768, 7, 7]
194 | return self.conv(fuse_feat+0.1*map_feat) # [1, 512, 7, 7]
195 |
196 |
197 | class NetG(nn.Module):
198 | def __init__(self, imsize, CLIP, ngf=64, nz=100, cond_dim=512, ch_size=3, mixed_precision=False):
199 | super(NetG, self).__init__()
200 | self.ngf = ngf
201 | self.mixed_precision = mixed_precision
202 | # build CLIP Mapper
203 | self.code_sz, self.code_ch, self.mid_ch = 7, 64, 32
204 | self.CLIP_ch = 768
205 | self.fc_code = nn.Linear(nz, self.code_sz*self.code_sz*self.code_ch)
206 | self.mapping = CLIP_Adapter(self.code_ch, self.mid_ch, self.code_ch, ngf*8, self.CLIP_ch, cond_dim+nz, 3, 1, 1, 4, CLIP)
207 | # build GBlocks
208 | self.GBlocks = nn.ModuleList([])
209 | in_out_pairs = list(get_G_in_out_chs(ngf, imsize))
210 | imsize = 4
211 | for idx, (in_ch, out_ch) in enumerate(in_out_pairs):
212 | if idx<(len(in_out_pairs)-1):
213 | imsize = imsize*2
214 | else:
215 | imsize = 224
216 | self.GBlocks.append(G_Block(cond_dim+nz, in_ch, out_ch, imsize))
217 | # to RGB image
218 | self.to_rgb = nn.Sequential(
219 | nn.LeakyReLU(0.2,inplace=True),
220 | nn.Conv2d(out_ch, ch_size, 3, 1, 1),
221 | #nn.Tanh(),
222 | )
223 |
224 | def forward(self, noise, c, eval=False): # x=noise, c=ent_emb
225 | with torch.cuda.amp.autocast() if self.mixed_precision and not eval else dummy_context_mgr() as mp:
226 | cond = torch.cat((noise, c), dim=1) # 612 dim, 100 + 512
227 | out = self.mapping(self.fc_code(noise).view(noise.size(0), self.code_ch, self.code_sz, self.code_sz), cond)
228 | # fc_code -> [1, 64, 7, 7]
229 | # out -> [1, 512, 7, 7]
230 | # fuse text and visual features
231 |             # progressively upsample the image features
232 | for GBlock in self.GBlocks:
233 | out = GBlock(out, cond)
234 | # [1, 64, 224, 224]
235 | # convert to RGB image
236 | out = self.to_rgb(out)
237 | return out
238 |
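# A minimal generation sketch (illustrative only; assumes a ViT-B/32 CLIP model,
# the default nz=100 noise dimension, and the 512-d CLIP sentence embedding as the
# condition):
def _netg_usage_sketch():
    CLIP_model, _ = clip.load("ViT-B/32", device="cpu")
    netG = NetG(imsize=224, CLIP=CLIP_model).eval()
    txt_encoder = CLIP_TXT_ENCODER(CLIP_model).eval()
    tokens = clip.tokenize(["a photo of a cat"])
    with torch.no_grad():
        sent_emb, _ = txt_encoder(tokens)                       # [1, 512]
        fake = netG(torch.randn(1, 100), sent_emb.float(), eval=True)
    return fake                                                 # [1, 3, 224, 224]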
239 |
240 | class NetD(nn.Module):
241 | def __init__(self, imsize, ndf=64, ch_size=3, mixed_precision=False):
242 | super(NetD, self).__init__()
243 | self.mixed_precision = mixed_precision
244 | self.DBlocks = nn.ModuleList([
245 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
246 | D_Block(768, 768, 3, 1, 1, res=True, CLIP_feat=True),
247 | ])
248 | self.main = D_Block(768, 512, 3, 1, 1, res=True, CLIP_feat=False)
249 |
250 | def forward(self, h):
251 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
252 | out = h[:,0]
253 | for idx in range(len(self.DBlocks)):
254 | out = self.DBlocks[idx](out, h[:,idx+1])
255 | out = self.main(out)
256 | return out
257 |
258 |
259 | class NetC(nn.Module):
260 | def __init__(self, ndf=64, cond_dim=512, mixed_precision=False):
261 | super(NetC, self).__init__()
262 | self.cond_dim = cond_dim
263 | self.mixed_precision = mixed_precision
264 | self.joint_conv = nn.Sequential(
265 | nn.Conv2d(512+512, 128, 4, 1, 0, bias=False),
266 | nn.LeakyReLU(0.2, inplace=True),
267 | nn.Conv2d(128, 1, 4, 1, 0, bias=False),
268 | )
269 |
270 | def forward(self, out, cond):
271 | with torch.cuda.amp.autocast() if self.mixed_precision else dummy_context_mgr() as mpc:
272 | cond = cond.view(-1, self.cond_dim, 1, 1)
273 | cond = cond.repeat(1, 1, 7, 7)
274 | h_c_code = torch.cat((out, cond), 1)
275 | out = self.joint_conv(h_c_code)
276 | return out
277 |
278 |
279 | class M_Block(nn.Module):
280 | def __init__(self, in_ch, mid_ch, out_ch, cond_dim, k, s, p):
281 | super(M_Block, self).__init__()
282 | self.conv1 = nn.Conv2d(in_ch, mid_ch, k, s, p)
283 | self.fuse1 = DFBLK(cond_dim, mid_ch)
284 | self.conv2 = nn.Conv2d(mid_ch, out_ch, k, s, p)
285 | self.fuse2 = DFBLK(cond_dim, out_ch)
286 | self.learnable_sc = in_ch != out_ch
287 | if self.learnable_sc:
288 | self.c_sc = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
289 |
290 | def shortcut(self, x):
291 | if self.learnable_sc:
292 | x = self.c_sc(x)
293 | return x
294 |
295 | def residual(self, h, text):
296 | h = self.conv1(h)
297 | h = self.fuse1(h, text)
298 | h = self.conv2(h)
299 | h = self.fuse2(h, text)
300 | return h
301 |
302 | def forward(self, h, c):
303 | return self.shortcut(h) + self.residual(h, c)
304 |
305 |
306 | class G_Block(nn.Module):
307 | def __init__(self, cond_dim, in_ch, out_ch, imsize):
308 | super(G_Block, self).__init__()
309 | self.imsize = imsize
310 | self.learnable_sc = in_ch != out_ch
311 | self.c1 = nn.Conv2d(in_ch, out_ch, 3, 1, 1)
312 | self.c2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1)
313 | self.fuse1 = DFBLK(cond_dim, in_ch)
314 | self.fuse2 = DFBLK(cond_dim, out_ch)
315 | if self.learnable_sc:
316 | self.c_sc = nn.Conv2d(in_ch,out_ch, 1, stride=1, padding=0)
317 |
318 | def shortcut(self, x):
319 | if self.learnable_sc:
320 | x = self.c_sc(x)
321 | return x
322 |
323 | def residual(self, h, y):
324 | h = self.fuse1(h, y)
325 | h = self.c1(h)
326 | h = self.fuse2(h, y)
327 | h = self.c2(h)
328 | return h
329 |
330 | def forward(self, h, y):
331 | h = F.interpolate(h, size=(self.imsize, self.imsize))
332 | return self.shortcut(h) + self.residual(h, y)
333 |
334 |
335 | class D_Block(nn.Module):
336 | def __init__(self, fin, fout, k, s, p, res, CLIP_feat):
337 | super(D_Block, self).__init__()
338 | self.res, self.CLIP_feat = res, CLIP_feat
339 | self.learned_shortcut = (fin != fout)
340 | self.conv_r = nn.Sequential(
341 | nn.Conv2d(fin, fout, k, s, p, bias=False),
342 | nn.LeakyReLU(0.2, inplace=True),
343 | nn.Conv2d(fout, fout, k, s, p, bias=False),
344 | nn.LeakyReLU(0.2, inplace=True),
345 | )
346 | self.conv_s = nn.Conv2d(fin, fout, 1, stride=1, padding=0)
347 | if self.res==True:
348 | self.gamma = nn.Parameter(torch.zeros(1))
349 | if self.CLIP_feat==True:
350 | self.beta = nn.Parameter(torch.zeros(1))
351 |
352 | def forward(self, x, CLIP_feat=None):
353 | res = self.conv_r(x)
354 | if self.learned_shortcut:
355 | x = self.conv_s(x)
356 | if (self.res==True)and(self.CLIP_feat==True):
357 | return x + self.gamma*res + self.beta*CLIP_feat
358 | elif (self.res==True)and(self.CLIP_feat!=True):
359 | return x + self.gamma*res
360 | elif (self.res!=True)and(self.CLIP_feat==True):
361 | return x + self.beta*CLIP_feat
362 | else:
363 | return x
364 |
365 |
366 | class DFBLK(nn.Module):
367 | def __init__(self, cond_dim, in_ch):
368 | super(DFBLK, self).__init__()
369 | self.affine0 = Affine(cond_dim, in_ch)
370 | self.affine1 = Affine(cond_dim, in_ch)
371 |
372 | def forward(self, x, y=None):
373 | h = self.affine0(x, y)
374 | h = nn.LeakyReLU(0.2,inplace=True)(h)
375 | h = self.affine1(h, y)
376 | h = nn.LeakyReLU(0.2,inplace=True)(h)
377 | return h
378 |
379 |
380 | class QuickGELU(nn.Module):
381 | def forward(self, x: torch.Tensor):
382 | return x * torch.sigmoid(1.702 * x)
383 |
384 |
385 | class Affine(nn.Module):
386 | def __init__(self, cond_dim, num_features):
387 | super(Affine, self).__init__()
388 |
389 | self.fc_gamma = nn.Sequential(OrderedDict([
390 | ('linear1',nn.Linear(cond_dim, num_features)),
391 | ('relu1',nn.ReLU(inplace=True)),
392 | ('linear2',nn.Linear(num_features, num_features)),
393 | ]))
394 | self.fc_beta = nn.Sequential(OrderedDict([
395 | ('linear1',nn.Linear(cond_dim, num_features)),
396 | ('relu1',nn.ReLU(inplace=True)),
397 | ('linear2',nn.Linear(num_features, num_features)),
398 | ]))
399 | self._initialize()
400 |
401 | def _initialize(self):
402 | nn.init.zeros_(self.fc_gamma.linear2.weight.data)
403 | nn.init.ones_(self.fc_gamma.linear2.bias.data)
404 | nn.init.zeros_(self.fc_beta.linear2.weight.data)
405 | nn.init.zeros_(self.fc_beta.linear2.bias.data)
406 |
407 | def forward(self, x, y=None):
408 | weight = self.fc_gamma(y)
409 | bias = self.fc_beta(y)
410 |
411 | if weight.dim() == 1:
412 | weight = weight.unsqueeze(0)
413 | if bias.dim() == 1:
414 | bias = bias.unsqueeze(0)
415 |
416 | size = x.size()
417 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
418 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
419 | return weight * x + bias
420 |
421 |
422 | def get_G_in_out_chs(nf, imsize):
423 | layer_num = int(np.log2(imsize))-1
424 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)]
425 | channel_nums = channel_nums[::-1]
426 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
427 | return in_out_pairs
428 |
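# Worked example: get_G_in_out_chs(nf=64, imsize=224)
#   layer_num    = int(log2(224)) - 1 = 6
#   channel_nums = [64, 128, 256, 512, 512, 512] -> reversed: [512, 512, 512, 256, 128, 64]
#   in_out_pairs = (512, 512), (512, 512), (512, 256), (256, 128), (128, 64)
# so NetG stacks five G_Blocks that upsample the 7x7 code to 8 -> 16 -> 32 -> 64 -> 224.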
429 |
430 | def get_D_in_out_chs(nf, imsize):
431 | layer_num = int(np.log2(imsize))-1
432 | channel_nums = [nf*min(2**idx, 8) for idx in range(layer_num)]
433 | in_out_pairs = zip(channel_nums[:-1], channel_nums[1:])
434 | return in_out_pairs
--------------------------------------------------------------------------------
/src/GALIP/ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | from torchvision import transforms
6 |
7 | def get_ratio(reg_every):
8 | if reg_every == 0:
9 | reg_ratio = 1
10 | else:
11 | reg_ratio = float(reg_every) / float(reg_every + 1)
12 |
13 | return reg_ratio
14 |
15 | def apply_gradients(loss, optim, mixed_flag=False, scaler_x=None, scaler_min=None):
16 | optim.zero_grad()
17 |
18 | if mixed_flag:
19 | scaler_x.scale(loss).backward()
20 | scaler_x.step(optim)
21 | scaler_x.update()
22 | if scaler_x.get_scale() < scaler_min:
23 | scaler_x.update(16384.0)
24 | else:
25 | loss.backward()
26 | optim.step()
27 |
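# A minimal usage sketch for the mixed-precision path (illustrative; the toy model,
# the random data, and the 128.0 loss-scale floor are all placeholders):
def _apply_gradients_amp_sketch():
    model = nn.Linear(8, 1).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()
    x, y = torch.randn(4, 8).cuda(), torch.randn(4, 1).cuda()
    with torch.cuda.amp.autocast():
        loss = F.mse_loss(model(x), y)
    # scaler_min guards against the loss scale collapsing to very small values
    apply_gradients(loss, optimizer, mixed_flag=True, scaler_x=scaler, scaler_min=128.0)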
28 | def moving_average(ema_model, origin_model, decay=0.999):
29 |     # ema_model    : the EMA copy
30 |     # origin_model : the model being trained
31 |
32 | with torch.no_grad():
33 | ema_param = dict(ema_model.named_parameters())
34 | origin_param = dict(origin_model.named_parameters())
35 |
36 | for k in ema_param.keys():
37 | ema_param[k].data.mul_(decay).add_(origin_param[k].data, alpha=1 - decay)
38 | # ema_param[k].data = decay * ema_param[k].data + (1 - decay) * origin_param[k].data
39 |
40 | def d_hinge_loss(real_pred, fake_pred, fake_pred2):
41 | real_loss = torch.mean(F.relu(1.0 - real_pred))
42 | fake_loss = torch.mean(F.relu(1.0 + fake_pred))
43 | if fake_pred2 is None:
44 | d_loss = real_loss + fake_loss
45 | else:
46 | fake_loss2 = torch.mean(F.relu(1.0 + fake_pred2))
47 | fake_loss = (fake_loss + fake_loss2) * 0.5
48 | d_loss = real_loss + fake_loss
49 | return d_loss
50 |
51 | def g_hinge_loss(fake_pred):
52 | g_loss = -torch.mean(fake_pred)
53 | return g_loss
54 |
55 |
56 | def d_logistic_loss(real_pred, fake_pred, fake_pred2):
57 | real_loss = F.softplus(-real_pred)
58 | fake_loss = F.softplus(fake_pred)
59 |
60 | if fake_pred2 is None:
61 | return real_loss.mean() + fake_loss.mean()
62 | else:
63 | fake_loss2 = F.softplus(fake_pred2)
64 | return real_loss.mean() + (fake_loss.mean() + fake_loss2.mean()) * 0.5
65 |
66 | def d_r1_loss(logits, real_img, text_embed=None):
67 | if text_embed is None:
68 | grad_real = torch.autograd.grad(
69 | outputs=logits.sum(),
70 | inputs=real_img,
71 | create_graph=True,
72 | )[0]
73 | grad_penalty = (grad_real ** 2).reshape(grad_real.shape[0], -1).sum(1).mean()
74 |
75 | else:
76 | grads = torch.autograd.grad(
77 | outputs=logits.sum(),
78 | inputs=(real_img, text_embed),
79 | create_graph=True,
80 | )
81 | grad0 = grads[0].view(grads[0].size(0), -1)
82 | grad1 = grads[1].view(grads[1].size(0), -1)
83 | grad = torch.cat((grad0, grad1), dim=1)
84 |         # the L2 norm would be torch.sqrt((grad ** 2).sum(1))
85 | grad_penalty = (grad ** 2).sum(1).mean()
86 |
87 | return grad_penalty
88 |
89 | def g_nonsaturating_loss(fake_pred):
90 | loss = F.softplus(-fake_pred).mean()
91 |
92 | return loss
93 |
94 | def d_adv_loss(real_pred, fake_pred, fake_pred2=None, gan_type='gan'):
95 | if gan_type == 'hinge':
96 | loss = d_hinge_loss(real_pred, fake_pred, fake_pred2)
97 | else:
98 | loss = d_logistic_loss(real_pred, fake_pred, fake_pred2)
99 |
100 | return loss
101 |
102 | def g_adv_loss(fake_pred, gan_type='gan'):
103 | if gan_type == 'hinge':
104 | loss = g_hinge_loss(fake_pred)
105 | else:
106 | loss = g_nonsaturating_loss(fake_pred)
107 |
108 | return loss
109 |
110 |
111 | def predict_loss(predictor, img_feature, text_feature, negtive):
112 | output = predictor(img_feature, text_feature)
113 | err = hinge_loss(output, negtive)
114 | return output,err
115 |
116 | def hinge_loss(output, negtive):
117 | if negtive==False:
118 | err = torch.mean(F.relu(1. - output))
119 | else:
120 | err = torch.mean(F.relu(1. + output))
121 | return err
122 |
123 | def MA_GP_FP32(img, sent, out):
124 | grads = torch.autograd.grad(outputs=out,
125 | inputs=(img, sent),
126 | grad_outputs=torch.ones(out.size()).cuda(),
127 | retain_graph=True,
128 | create_graph=True,
129 | only_inputs=True)
130 | grad0 = grads[0].view(grads[0].size(0), -1)
131 | grad1 = grads[1].view(grads[1].size(0), -1)
132 | grad = torch.cat((grad0,grad1),dim=1)
133 | grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
134 | d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6)
135 | return d_loss_gp
136 |
137 | def MA_GP_MP(img, sent, out, scaler):
138 | grads = torch.autograd.grad(outputs=scaler.scale(out),
139 | inputs=(img, sent),
140 | grad_outputs=torch.ones_like(out),
141 | retain_graph=True,
142 | create_graph=True,
143 | only_inputs=True)
144 |     inv_scale = 1. / (scaler.get_scale() + 1e-8)
145 | #inv_scale = 1./scaler.get_scale()
146 | grads = [grad * inv_scale for grad in grads]
147 | with torch.cuda.amp.autocast():
148 | grad0 = grads[0].view(grads[0].size(0), -1)
149 | grad1 = grads[1].view(grads[1].size(0), -1)
150 | grad = torch.cat((grad0,grad1),dim=1)
151 | grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
152 | d_loss_gp = 2.0 * torch.mean((grad_l2norm) ** 6)
153 | return d_loss_gp
154 |
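# Note on MA_GP_MP above: autograd runs on the *scaled* output, so the returned
# gradients carry the GradScaler's loss-scale factor. Dividing by scaler.get_scale()
# (with a small epsilon for safety) recovers the true gradients before the penalty
# 2 * E[||grad||^6] is computed, mirroring the FP32 version MA_GP_FP32.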
155 | # clip loss
156 | def clip_image_process(x):
157 | def denormalize(x):
158 |         # [-1, 1] -> [0, 255]
159 | x = ((x + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
160 |
161 | return x
162 |
163 | def resize(x):
164 | x = transforms.Resize(size=[224, 224], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)(x)
165 | return x
166 |
167 | def zero_to_one(x):
168 | x = x.float() / 255.0
169 | return x
170 |
171 | def norm_mean_std(x):
172 | mean = [0.48145466, 0.4578275, 0.40821073]
173 | std = [0.26862954, 0.26130258, 0.27577711]
174 | x = transforms.Normalize(mean=mean, std=std, inplace=True)(x)
175 | return x
176 |
177 |
178 | x = denormalize(x)
179 | x = resize(x)
180 | x = zero_to_one(x)
181 | x = norm_mean_std(x)
182 |
183 | return x
184 |
185 | def cosine_sim_loss(image_feat, text_feat):
186 | image_feat = image_feat / image_feat.norm(p=2, dim=-1, keepdim=True)
187 | text_feat = text_feat / text_feat.norm(p=2, dim=-1, keepdim=True)
188 |
189 | loss = -F.cosine_similarity(image_feat, text_feat).mean()
190 | return loss
191 |
192 | def clip_score(clip_model, image, text):
193 | txt_features = clip_model.get_text_features(text)
194 |
195 | processed_image = clip_image_process(image)
196 | img_features = clip_model.get_image_features(processed_image)
197 |
198 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
199 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
200 |
201 | # score = 100 * (img_features * txt_features).sum(axis=-1)
202 | # score = torch.mean(score)
203 |
204 | score = -F.cosine_similarity(img_features, txt_features).mean()
205 |
206 | return score
207 |
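# Note: despite its name, clip_score() above returns the *negative* mean cosine
# similarity, i.e. a quantity to minimize during training. The commented-out lines
# show the conventional CLIPScore form (100 * cosine similarity). The same sign
# convention applies to clip_image_score() below.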
208 | def clip_image_score(clip_model, image1, image2):
209 | processed_image1 = clip_image_process(image1)
210 | processed_image2 = clip_image_process(image2)
211 |
212 | img_features1 = clip_model.get_image_features(processed_image1)
213 | img_features2 = clip_model.get_image_features(processed_image2)
214 |
215 | img_features1 = img_features1 / img_features1.norm(p=2, dim=-1, keepdim=True)
216 | img_features2 = img_features2 / img_features2.norm(p=2, dim=-1, keepdim=True)
217 |
218 | # score = 100 * (img_features1 * img_features2).sum(axis=-1)
219 | # score = torch.mean(score)
220 |
221 | score = -F.cosine_similarity(img_features1, img_features2).mean()
222 |
223 | return score
224 |
225 | def contrastive_loss(logits, dim) :
226 | neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
227 | return -neg_ce.mean()
228 |
229 | def clip_score_(clip_model, image, text):
230 | txt_features = clip_model.get_text_features(text)
231 |
232 | processed_image = clip_image_process(image)
233 | img_features = clip_model.get_image_features(processed_image)
234 |
235 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
236 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
237 |
238 | # cosine similarity as logits
239 | logit_scale = clip_model.logit_scale.exp()
240 | similarity = torch.matmul(txt_features, img_features.t()) * logit_scale
241 |
242 | caption_loss = contrastive_loss(similarity, dim=0)
243 | image_loss = contrastive_loss(similarity, dim=1)
244 |
245 | return (caption_loss + image_loss) / 2.0
246 |
247 | def clip_image_score_(clip_model, image1, image2):
248 | processed_image1 = clip_image_process(image1)
249 | processed_image2 = clip_image_process(image2)
250 |
251 | img_features1 = clip_model.get_image_features(processed_image1)
252 | img_features2 = clip_model.get_image_features(processed_image2)
253 |
254 | img_features1 = img_features1 / img_features1.norm(p=2, dim=-1, keepdim=True)
255 | img_features2 = img_features2 / img_features2.norm(p=2, dim=-1, keepdim=True)
256 |
257 | # cosine similarity as logits
258 | logit_scale = clip_model.logit_scale.exp()
259 | similarity = torch.matmul(img_features1, img_features2.t()) * logit_scale
260 |
261 | caption_loss = contrastive_loss(similarity, dim=0)
262 | image_loss = contrastive_loss(similarity, dim=1)
263 |
264 | return (caption_loss + image_loss) / 2.0
265 |
266 | def convert_to_billion_and_million(value, decimal_places=2):
267 | billion = round(value / 1_000_000_000, decimal_places)
268 | million = round(value / 1_000_000, decimal_places)
269 |
270 | return billion, million
--------------------------------------------------------------------------------
/src/GALIP/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | from torchvision import transforms
5 | from torch.utils.data import Dataset
6 | from PIL import Image
7 | import os, re
8 | from glob import glob
9 | import torch.distributed as dist
10 | from torch.nn.parallel import DistributedDataParallel
11 | import torch.multiprocessing as torch_multiprocessing
12 |
13 | import json
14 | import requests
15 | import traceback
16 |
17 | from transformers import CLIPTokenizer, CLIPModel
18 | import torch.nn as nn
19 |
20 | import numpy as np
21 | import torch.nn.functional as F
22 | import random
23 | import clip
24 |
25 | class ImageTextDataset(Dataset):
26 | def __init__(self, dataset_path, img_size, imagenet_normalization=False, max_length=77):
27 | self.image_samples, self.text_samples = self.listdir(dataset_path)
28 | self.max_length = max_length
29 |
30 | transform_list = image_preprocess(img_size, imagenet_normalization)
31 | self.transform = transforms.Compose(transform_list)
32 |
33 | # self.tokenizer, self.clip = FrozenNetwork(max_length=max_length).load()
34 |
35 | def listdir(self, dir_path):
36 | img_extensions = ['png', 'jpg', 'jpeg', 'JPG']
37 | image_list = []
38 | for ext in img_extensions:
39 | image_list += glob(os.path.join(dir_path, 'image', '*.' + ext))
40 | image_list.sort()
41 |
42 | txt_extensions = ['txt']
43 | text_list = []
44 | for ext in txt_extensions:
45 | text_list += glob(os.path.join(dir_path, 'text', '*.' + ext))
46 | text_list.sort()
47 |
48 | return image_list, text_list
49 |
50 | def __getitem__(self, index):
51 | image_path, text_path = self.image_samples[index], self.text_samples[index]
52 | img = Image.open(image_path).convert('RGB')
53 | txt = text_read(text_path)
54 |
55 | img = self.transform(img)
56 |
57 | # batch_encoding = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
58 | # tokens = batch_encoding["input_ids"] # [1, 77]
59 | tokens = clip.tokenize(txt, truncate=True)
60 | tokens = torch.squeeze(tokens)
61 | # tokens = tokens.to(self.clip_text_encoder.device)
62 | # outputs = self.clip_text_encoder(input_ids=tokens)
63 | # txt_embed = outputs.last_hidden_state # [77, 768]
64 |
65 | return img, tokens, txt
66 |
67 | def __len__(self):
68 | return len(self.image_samples)
69 |
70 | class ImageDataset(Dataset):
71 | def __init__(self, dataset_path, img_size, imagenet_normalization=False):
72 | self.image_samples = self.listdir(dataset_path)
73 |
74 | transform_list = image_preprocess(img_size, imagenet_normalization)
75 | self.transform = transforms.Compose(transform_list)
76 |
77 | def listdir(self, dir_path):
78 | img_extensions = ['png', 'jpg', 'jpeg', 'JPG']
79 | image_list = []
80 | for ext in img_extensions:
81 | image_list += glob(os.path.join(dir_path, 'image', '*.' + ext))
82 | image_list.sort()
83 |
84 | return image_list
85 |
86 | def __getitem__(self, index):
87 | image_path = self.image_samples[index]
88 | img = Image.open(image_path).convert('RGB')
89 |
90 | img = self.transform(img)
91 |
92 | return img
93 |
94 | def __len__(self):
95 | return len(self.image_samples)
96 |
97 | class FrozenNetwork(nn.Module):
98 |     """Load a frozen CLIP tokenizer and CLIP model for text features (the SD autoencoder from the reference is not loaded here)."""
99 | # https://github.com/baofff/U-ViT/blob/f0f35a9e710688ec669ae7154c490a8053f3139f/libs/clip.py
100 | def __init__(self, autoencoder_version="runwayml/stable-diffusion-v1-5", clip_version="openai/clip-vit-large-patch14", max_length=77):
101 | super().__init__()
102 | self.max_length = max_length
103 |
104 | self.tokenizer = CLIPTokenizer.from_pretrained(clip_version)
105 | self.clip = CLIPModel.from_pretrained(clip_version)
106 |
107 | self.freeze()
108 |
109 | def freeze(self):
110 | self.clip.eval()
111 |
112 | def load(self):
113 | return self.tokenizer, self.clip
114 |
115 |
116 | def image_preprocess(img_size, imagenet_normalization=False):
117 | # interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
118 | if imagenet_normalization:
119 | mean = [0.485, 0.456, 0.406]
120 | std = [0.229, 0.224, 0.225]
121 | transform_list = [
122 | transforms.Resize(size=[img_size, img_size], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), # [h, w]
123 | transforms.ToTensor(), # [0, 255] -> [0, 1] # [c, h, w]
124 | transforms.Normalize(mean=mean, std=std, inplace=True), # [0, 1] -> [-1, 1]
125 | ]
126 | else:
127 | transform_list = [
128 | transforms.Resize(size=[img_size, img_size], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
129 | transforms.ToTensor(), # [0, 255] -> [0, 1]
130 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), # [0, 1] -> [-1, 1]
131 | ]
132 |
133 | return transform_list
134 |
135 | def text_read(text_path):
136 | with open(text_path, 'r') as f:
137 | x = f.readlines()
138 |
139 | t = [text.strip() for text in x] # remove \n
140 |
141 | t_sample = random.choice(t)
142 |
143 | return t_sample
144 |
145 |
146 | def check_folder(log_dir):
147 | if not os.path.exists(log_dir):
148 | os.makedirs(log_dir)
149 | return log_dir
150 |
151 |
152 | def str2bool(x):
153 |     return x.lower() == 'true'
154 |
155 |
156 | def multi_gpu_run(ddp_fn, args): # in main
157 | # ddp_fn = train_fn
158 | world_size = torch.cuda.device_count() # ngpus
159 | torch_multiprocessing.spawn(fn=ddp_fn, args=(args, world_size), nprocs=world_size, join=True)
160 |
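# Note: torch.multiprocessing.spawn prepends the process index to the arguments, so
# the function passed as ddp_fn is invoked as ddp_fn(rank, args, world_size); each
# worker is then expected to call build_init_procss(rank, world_size, device) below
# to join the NCCL process group.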
161 |
162 | def build_init_procss(rank, world_size, device): # in build
163 | os.environ["MASTER_ADDR"] = "127.0.0.1" # localhost
164 | os.environ["MASTER_PORT"] = "12355"
165 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
166 | synchronize()
167 | torch.cuda.set_device(device)
168 |
169 |
170 | def distributed_sampler(dataset, rank, num_replicas, shuffle):
171 | return torch.utils.data.distributed.DistributedSampler(dataset, rank=rank, num_replicas=num_replicas, shuffle=shuffle)
172 | # return torch.utils.data.RandomSampler(dataset)
173 |
174 |
175 | def infinite_iterator(loader):
176 | while True:
177 | for batch in loader:
178 | yield batch
179 |
180 | def find_latest_ckpt(folder):
181 | files = []
182 | for fname in os.listdir(folder):
183 | s = re.findall(r'\d+', fname)
184 | if len(s) == 1:
185 | files.append((int(s[0]), fname))
186 | if files:
187 | file_name = max(files)[1]
188 | index = os.path.splitext(file_name)[0]
189 | return file_name, index
190 | else:
191 | return None, 0
192 |
193 |
194 | def broadcast_params(model):
195 | params = model.parameters()
196 | for param in params:
197 | dist.broadcast(param.data, src=0)
198 | dist.barrier()
199 | torch.cuda.synchronize()
200 |
201 |
202 | def dataparallel_and_sync(model, local_rank, find_unused_parameters=False):
203 | # DistributedDataParallel
204 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=find_unused_parameters)
205 |
206 | # broadcast
207 | broadcast_params(model)
208 |
209 | model = model.module
210 |
211 | return model
212 |
213 | def cleanup():
214 | dist.destroy_process_group()
215 |
216 | def get_rank():
217 | if not dist.is_available():
218 | return 0
219 |
220 | if not dist.is_initialized():
221 | return 0
222 |
223 | return dist.get_rank()
224 |
225 | def get_world_size():
226 | if not dist.is_available():
227 | return 1
228 |
229 | if not dist.is_initialized():
230 | return 1
231 |
232 | return dist.get_world_size()
233 |
234 | def synchronize():
235 | if not dist.is_available():
236 | return
237 |
238 | if not dist.is_initialized():
239 | return
240 |
241 | world_size = dist.get_world_size()
242 |
243 | if world_size == 1:
244 | return
245 |
246 | dist.barrier()
247 |
248 | def reduce_loss_dict(loss_dict):
249 | world_size = get_world_size()
250 |
251 | if world_size < 2:
252 | return loss_dict
253 |
254 | with torch.no_grad():
255 | keys = []
256 | losses = []
257 |
258 | for k in sorted(loss_dict.keys()):
259 | keys.append(k)
260 | losses.append(loss_dict[k])
261 |
262 | losses = torch.stack(losses, 0)
263 | dist.reduce(losses, dst=0)
264 |
265 | if dist.get_rank() == 0:
266 | losses /= world_size
267 |
268 | reduced_losses = {k: v.mean().item() for k, v in zip(keys, losses)}
269 |
270 | return reduced_losses
271 |
272 |
273 | def dict_to_numpy(x_dict, python_value=False):
274 | losses_numpy = {}
275 | for k,v in x_dict.items():
276 | losses_numpy[k] = tensor_to_numpy(v, python_value=python_value)
277 |
278 | return losses_numpy
279 |
280 | def tensor_to_numpy(x, python_value=False):
281 | if isinstance(x, torch.Tensor):
282 | if python_value:
283 | return x.detach().cpu().numpy().tolist()
284 | else:
285 | return x.detach().cpu().numpy()
286 | else:
287 | return x
288 |
289 | def get_val(x):
290 | x_val = x.mean().item()
291 |
292 | return x_val
293 |
294 | def requires_grad(model, flag=True):
295 | for p in model.parameters():
296 | p.requires_grad = flag
297 |
298 | def count_parameters(model):
299 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
300 |
--------------------------------------------------------------------------------
/src/ddpm_ddim/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/.DS_Store
--------------------------------------------------------------------------------
/src/ddpm_ddim/dataset/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/.DS_Store
--------------------------------------------------------------------------------
/src/ddpm_ddim/dataset/cat/flickr_cat_000008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000008.png
--------------------------------------------------------------------------------
/src/ddpm_ddim/dataset/cat/flickr_cat_000011.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000011.png
--------------------------------------------------------------------------------
/src/ddpm_ddim/dataset/cat/flickr_cat_000016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000016.png
--------------------------------------------------------------------------------
/src/ddpm_ddim/dataset/cat/flickr_cat_000056.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000056.png
--------------------------------------------------------------------------------
/src/ddpm_ddim/dataset/cat/flickr_cat_000076.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/dataset/cat/flickr_cat_000076.png
--------------------------------------------------------------------------------
/src/ddpm_ddim/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from matplotlib import pyplot as plt
5 | from tqdm import tqdm
6 | from torch import optim
7 | from utils import *
8 | from modules import UNet, linear_beta_schedule, cosine_beta_schedule
9 | import logging
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torch.nn.functional as F
12 |
13 | logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")
14 |
15 |
16 | class Diffusion:
17 | def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, objective='ddpm', schedule='linear', device="cuda"):
18 | self.noise_steps = noise_steps
19 | self.beta_start = beta_start
20 | self.beta_end = beta_end
21 | self.img_size = img_size
22 | self.device = device
23 |
24 | self.objective = objective
25 |
26 | self.beta = self.prepare_noise_schedule(schedule, beta_start, beta_end).to(device)
27 |
28 | self.alpha = 1. - self.beta
29 | self.alpha_hat = torch.cumprod(self.alpha, dim=0)
30 |
31 | def prepare_noise_schedule(self, schedule, beta_start, beta_end):
32 | if schedule == 'linear':
33 | return linear_beta_schedule(self.noise_steps, beta_start, beta_end)
34 | else:
35 | return cosine_beta_schedule(self.noise_steps)
36 |
37 | def noise_images(self, x, t):
38 | sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
39 | sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
40 | z = torch.randn_like(x)
41 | return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * z, z
42 |
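    # noise_images() above is the closed-form forward process
    #   q(x_t | x_0) = N( sqrt(alpha_hat_t) * x_0, (1 - alpha_hat_t) * I ),
    # sampled via the reparameterization
    #   x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * z,  z ~ N(0, I).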
43 | def sample_timesteps(self, n):
44 | t = torch.randint(low=1, high=self.noise_steps, size=(n,))
45 | return t
46 |
47 | def tensor_to_image(self, x):
48 | x = (x.clamp(-1, 1) + 1) / 2
49 | x = (x * 255).type(torch.uint8)
50 | return x
51 |
52 | def sample(self, model, n):
53 | # reverse process
54 | logging.info(f"Sampling {n} new images....")
55 | model.eval()
56 | with torch.no_grad():
57 | x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
58 | for i in tqdm(reversed(range(1, self.noise_steps))):
59 | t = (torch.ones(n, dtype=torch.long) * i).to(self.device)
60 |
61 | alpha = self.alpha[t][:, None, None, None]
62 | beta = self.beta[t][:, None, None, None]
63 | alpha_hat = self.alpha_hat[t][:, None, None, None]
64 | alpha_hat_prev = self.alpha_hat[t-1][:, None, None, None]
65 |                 beta_tilde = beta * (1 - alpha_hat_prev) / (1 - alpha_hat) # posterior variance beta_tilde_t; close to beta_t for large t
66 |
67 | predicted_noise = model(x, t)
68 | noise = torch.randn_like(x)
69 |
70 | if self.objective == 'ddpm':
71 | predict_x0 = 0
72 | direction_point = 1 / torch.sqrt(alpha) * (x - (beta / (torch.sqrt(1 - alpha_hat))) * predicted_noise)
73 | random_noise = torch.sqrt(beta_tilde) * noise
74 |
75 | x = predict_x0 + direction_point + random_noise
76 | else:
77 | predict_x0 = torch.sqrt(alpha_hat_prev) * (x - torch.sqrt(1 - alpha_hat) * predicted_noise) / torch.sqrt(alpha_hat)
78 | direction_point = torch.sqrt(1 - alpha_hat_prev) * predicted_noise
79 | random_noise = 0
80 |
81 | x = predict_x0 + direction_point + random_noise
82 |
83 | model.train()
84 | return torch.clamp(x, -1.0, 1.0)
85 |
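# Reverse-process updates implemented in Diffusion.sample() above:
#   DDPM: x_{t-1} = 1/sqrt(alpha_t) * ( x_t - beta_t / sqrt(1 - alpha_hat_t) * eps_theta )
#                   + sqrt(beta_tilde_t) * z,   z ~ N(0, I)
#   DDIM (eta = 0):
#         x0_pred = ( x_t - sqrt(1 - alpha_hat_t) * eps_theta ) / sqrt(alpha_hat_t)
#         x_{t-1} = sqrt(alpha_hat_{t-1}) * x0_pred + sqrt(1 - alpha_hat_{t-1}) * eps_theta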
86 |
87 | def train(args):
88 | setup_logging(args.run_name)
89 | device = args.device
90 | dataloader = get_data(args)
91 | model = UNet(device=device).to(device)
92 | optimizer = optim.AdamW(model.parameters(), lr=args.lr)
93 | mse = nn.MSELoss()
94 | diffusion = Diffusion(img_size=args.image_size, device=device)
95 | logger = SummaryWriter(os.path.join("logs", args.run_name))
96 | l = len(dataloader)
97 |
98 | for epoch in range(args.epochs):
99 | logging.info(f"Starting epoch {epoch}:")
100 | pbar = tqdm(dataloader)
101 | for i, images in enumerate(pbar):
102 | images = images.to(device)
103 | t = diffusion.sample_timesteps(images.shape[0]).to(device)
104 | x_t, noise = diffusion.noise_images(images, t)
105 | predicted_noise = model(x_t, t)
106 | loss = mse(noise, predicted_noise)
107 |
108 | optimizer.zero_grad()
109 | loss.backward()
110 | optimizer.step()
111 |
112 | pbar.set_postfix(MSE=loss.item())
113 | logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
114 |
115 | sampled_images = diffusion.sample(model, n=images.shape[0])
116 | save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.png"))
117 | torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))
118 |
119 |
120 | def launch():
121 | import argparse
122 | parser = argparse.ArgumentParser()
123 | args = parser.parse_args()
124 | args.epochs = 100
125 | args.batch_size = 16
126 | args.image_size = 64
127 | args.objective = 'ddpm'
128 | args.schedule = 'linear'
129 | args.dataset_path = "../dataset/cat"
130 | args.device = "cuda"
131 | args.lr = 3e-4
132 |
133 | args.run_name = "diffusion_{}_{}".format(args.objective, args.schedule)
134 | train(args)
135 |
136 |
137 | if __name__ == '__main__':
138 | launch()
139 |
--------------------------------------------------------------------------------
/src/ddpm_ddim/main_template.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from matplotlib import pyplot as plt
5 | from tqdm import tqdm
6 | from torch import optim
7 | from utils import *
8 | from modules import UNet, linear_beta_schedule, cosine_beta_schedule
9 | import logging
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torch.nn.functional as F
12 |
13 | logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")
14 |
15 |
16 | class Diffusion:
17 | def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=256, objective='ddpm', schedule='linear', device="cuda"):
18 | self.noise_steps = noise_steps
19 | self.beta_start = beta_start
20 | self.beta_end = beta_end
21 | self.img_size = img_size
22 | self.device = device
23 |
24 | self.objective = objective
25 |
26 | self.beta = self.prepare_noise_schedule(schedule, beta_start, beta_end).to(device)
27 |
28 | """
29 | Step 1.
30 |
31 | self.alpha = ?
32 | self.alpha_hat = ?
33 |
34 | """
35 |
36 |
37 | def prepare_noise_schedule(self, schedule, beta_start, beta_end):
38 | if schedule == 'linear':
39 | return linear_beta_schedule(self.noise_steps, beta_start, beta_end)
40 | else:
41 | return cosine_beta_schedule(self.noise_steps)
42 |
43 | def sample_timesteps(self, n):
44 | """
45 | Step 2.
46 |         Sample n random timesteps in the range [1, self.noise_steps].
47 |
48 |         :param n: int
49 |         :return: a tensor of shape [n, ]
50 |
51 |         Note: since these are timesteps, the values must be integers.
52 |
53 | """
54 | return
55 |
56 | def noise_images(self, x, t):
57 | """
58 | Step 3.
59 |         Implement the forward process.
60 |         -> This is the step that adds noise to the images.
61 |
62 |         Return both the noised images and the noise that was added -- two outputs in total.
63 |
64 | :param x: [n, 3, img_size, img_size]
65 | :param t: [n, ]
66 | :return: [n, 3, img_size, img_size], [n, 3, img_size, img_size]
67 |
68 | """
69 | return
70 |
71 |
72 | def sample(self, model, n):
73 | """
74 |         Step 5. The last one!
75 |         Complete the reverse process.
76 |
77 | :param model: Unet
78 | :param n: batch_size
79 | :return: x: [n, 3, img_size, img_size]
80 | """
81 | logging.info(f"Sampling {n} new images....")
82 | model.eval()
83 | with torch.no_grad():
84 | """
85 |             (1) Since denoising starts from step T, create a Gaussian noise tensor.
86 |             (2) Write a loop that denoises from T (self.noise_steps) downwards.
87 |                 hint: t should run T, T-1, T-2, ..., 3, 2, 1.
88 |             (3) For each t, build alpha_t, beta_t, alpha_hat_t, alpha_hat_(t-1), and beta_tilde.
89 |
90 |             (4) Feed the noise from (1) and the t from (2) into the model to predict the noise.
91 |             (5) Using the predicted noise, write both the DDPM and DDIM sampling updates.
92 |
93 | """
94 |
95 | model.train()
96 | return torch.clamp(x, -1.0, 1.0)
97 |
98 |
99 | def train(args):
100 | setup_logging(args.run_name)
101 | device = args.device
102 | dataloader = get_data(args)
103 | model = UNet(device=device).to(device)
104 | optimizer = optim.AdamW(model.parameters(), lr=args.lr)
105 | diffusion = Diffusion(img_size=args.image_size, device=device)
106 | logger = SummaryWriter(os.path.join("logs", args.run_name))
107 | l = len(dataloader)
108 |
109 | for epoch in range(args.epochs):
110 | logging.info(f"Starting epoch {epoch}:")
111 | pbar = tqdm(dataloader)
112 | for i, images in enumerate(pbar):
113 | images = images.to(device)
114 | """
115 | Step 4.
116 |             Write the training code.
117 |             The hints below should be enough to complete it.
118 |
119 |             hint:
120 |             (1) Sample the timesteps.
121 |             (2) Create the noised images corresponding to timestep t.
122 |             (3) Feed them into the model to predict the noise.
123 |             (4) Choose an appropriate loss. (L1 or L2)
124 | """
125 |
126 | optimizer.zero_grad()
127 | loss.backward()
128 | optimizer.step()
129 |
130 | pbar.set_postfix(Loss=loss.item())
131 | logger.add_scalar("diffusion loss", loss.item(), global_step=epoch * l + i)
132 |
133 | sampled_images = diffusion.sample(model, n=images.shape[0])
134 | save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.png"))
135 | torch.save(model.state_dict(), os.path.join("models", args.run_name, f"ckpt.pt"))
136 |
137 |
138 | def launch():
139 | import argparse
140 | parser = argparse.ArgumentParser()
141 | args = parser.parse_args()
142 | args.epochs = 100
143 | args.batch_size = 16
144 | args.image_size = 64
145 | args.objective = 'ddpm'
146 | args.schedule = 'linear'
147 | args.dataset_path = "../dataset/cat"
148 | args.device = "cpu"
149 | args.lr = 3e-4
150 |
151 | args.run_name = "diffusion_{}_{}".format(args.objective, args.schedule)
152 | train(args)
153 |
154 |
155 | if __name__ == '__main__':
156 | launch()
157 |
--------------------------------------------------------------------------------
/src/ddpm_ddim/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | def linear_beta_schedule(timesteps, beta_start, beta_end):
7 | scale = 1000 / timesteps
8 | beta_start = scale * beta_start
9 | beta_end = scale * beta_end
10 |
11 | return torch.linspace(beta_start, beta_end, timesteps)
12 |
13 | def cosine_beta_schedule(timesteps, s = 0.008):
14 | """
15 | cosine schedule
16 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
17 | """
18 | steps = timesteps + 1
19 | x = torch.linspace(0, timesteps, steps)
20 |
21 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
22 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
23 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
24 |
25 | return torch.clamp(betas, 0, 0.999)
26 |
27 | class EMA:
28 | def __init__(self, beta):
29 | super().__init__()
30 | self.beta = beta
31 | self.step = 0
32 |
33 | def update_model_average(self, ma_model, current_model):
34 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
35 | old_weight, up_weight = ma_params.data, current_params.data
36 | ma_params.data = self.update_average(old_weight, up_weight)
37 |
38 | def update_average(self, old, new):
39 | if old is None:
40 | return new
41 | return old * self.beta + (1 - self.beta) * new
42 |
43 | def step_ema(self, ema_model, model, step_start_ema=2000):
44 | if self.step < step_start_ema:
45 | self.reset_parameters(ema_model, model)
46 | self.step += 1
47 | return
48 | self.update_model_average(ema_model, model)
49 | self.step += 1
50 |
51 | def reset_parameters(self, ema_model, model):
52 | ema_model.load_state_dict(model.state_dict())
53 |
54 |
55 | class SelfAttention(nn.Module):
56 | def __init__(self, channels, size):
57 | super(SelfAttention, self).__init__()
58 | self.channels = channels
59 | self.size = size
60 | self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
61 | self.ln = nn.LayerNorm([channels])
62 | self.ff_self = nn.Sequential(
63 | nn.LayerNorm([channels]),
64 | nn.Linear(channels, channels),
65 | nn.GELU(),
66 | nn.Linear(channels, channels),
67 | )
68 |
69 | def forward(self, x):
70 | x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
71 | x_ln = self.ln(x)
72 | attention_value, _ = self.mha(x_ln, x_ln, x_ln)
73 | attention_value = attention_value + x
74 | attention_value = self.ff_self(attention_value) + attention_value
75 | return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
76 |
77 |
78 | class DoubleConv(nn.Module):
79 | def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
80 | super().__init__()
81 | self.residual = residual
82 | if not mid_channels:
83 | mid_channels = out_channels
84 | self.double_conv = nn.Sequential(
85 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
86 | nn.GroupNorm(1, mid_channels),
87 | nn.GELU(),
88 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
89 | nn.GroupNorm(1, out_channels),
90 | )
91 |
92 | def forward(self, x):
93 | if self.residual:
94 | return F.gelu(x + self.double_conv(x))
95 | else:
96 | return self.double_conv(x)
97 |
98 |
99 | class Down(nn.Module):
100 | def __init__(self, in_channels, out_channels, emb_dim=256):
101 | super().__init__()
102 | self.maxpool_conv = nn.Sequential(
103 | nn.MaxPool2d(2),
104 | DoubleConv(in_channels, in_channels, residual=True),
105 | DoubleConv(in_channels, out_channels),
106 | )
107 |
108 | self.emb_layer = nn.Sequential(
109 | nn.SiLU(),
110 | nn.Linear(
111 | emb_dim,
112 | out_channels
113 | ),
114 | )
115 |
116 | def forward(self, x, t):
117 | x = self.maxpool_conv(x)
118 | emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
119 | return x + emb
120 |
121 |
122 | class Up(nn.Module):
123 | def __init__(self, in_channels, out_channels, emb_dim=256):
124 | super().__init__()
125 |
126 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
127 | self.conv = nn.Sequential(
128 | DoubleConv(in_channels, in_channels, residual=True),
129 | DoubleConv(in_channels, out_channels, in_channels // 2),
130 | )
131 |
132 | self.emb_layer = nn.Sequential(
133 | nn.SiLU(),
134 | nn.Linear(
135 | emb_dim,
136 | out_channels
137 | ),
138 | )
139 |
140 | def forward(self, x, skip_x, t):
141 | x = self.up(x)
142 | x = torch.cat([skip_x, x], dim=1)
143 | x = self.conv(x)
144 | emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
145 | return x + emb
146 |
147 |
148 | class UNet(nn.Module):
149 | def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
150 | super().__init__()
151 | self.device = device
152 | self.time_dim = time_dim
153 | self.inc = DoubleConv(c_in, 64)
154 | self.down1 = Down(64, 128)
155 | self.sa1 = SelfAttention(128, 32)
156 | self.down2 = Down(128, 256)
157 | self.sa2 = SelfAttention(256, 16)
158 | self.down3 = Down(256, 256)
159 | self.sa3 = SelfAttention(256, 8)
160 |
161 | self.bot1 = DoubleConv(256, 512)
162 | self.bot2 = DoubleConv(512, 512)
163 | self.bot3 = DoubleConv(512, 256)
164 |
165 | self.up1 = Up(512, 128)
166 | self.sa4 = SelfAttention(128, 16)
167 | self.up2 = Up(256, 64)
168 | self.sa5 = SelfAttention(64, 32)
169 | self.up3 = Up(128, 64)
170 | self.sa6 = SelfAttention(64, 64)
171 | self.outc = nn.Conv2d(64, c_out, kernel_size=1)
172 |
173 | def pos_encoding(self, t, channels):
174 | inv_freq = 1.0 / (10000** (torch.arange(0, channels, 2, device=self.device).float() / channels))
175 | pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
176 | pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
177 | pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
178 | return pos_enc
179 |
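    # pos_encoding() maps a timestep tensor of shape [n, 1] to a sinusoidal embedding of
    # shape [n, channels]: sin/cos pairs at frequencies w_k = 1 / 10000^(2k / channels),
    # the same scheme as the transformer positional encoding.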
180 | def forward(self, x, t):
181 | t = t.unsqueeze(-1).type(torch.float)
182 | t = self.pos_encoding(t, self.time_dim)
183 |
184 | x1 = self.inc(x)
185 | x2 = self.down1(x1, t)
186 | x2 = self.sa1(x2)
187 | x3 = self.down2(x2, t)
188 | x3 = self.sa2(x3)
189 | x4 = self.down3(x3, t)
190 | x4 = self.sa3(x4)
191 |
192 | x4 = self.bot1(x4)
193 | x4 = self.bot2(x4)
194 | x4 = self.bot3(x4)
195 |
196 | x = self.up1(x4, x3, t)
197 | x = self.sa4(x)
198 | x = self.up2(x, x2, t)
199 | x = self.sa5(x)
200 | x = self.up3(x, x1, t)
201 | x = self.sa6(x)
202 | output = self.outc(x)
203 | return output
204 |
205 |
206 | class UNet_conditional(nn.Module):
207 | def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, device="cuda"):
208 | super().__init__()
209 | self.device = device
210 | self.time_dim = time_dim
211 | self.inc = DoubleConv(c_in, 64)
212 | self.down1 = Down(64, 128)
213 | self.sa1 = SelfAttention(128, 32)
214 | self.down2 = Down(128, 256)
215 | self.sa2 = SelfAttention(256, 16)
216 | self.down3 = Down(256, 256)
217 | self.sa3 = SelfAttention(256, 8)
218 |
219 | self.bot1 = DoubleConv(256, 512)
220 | self.bot2 = DoubleConv(512, 512)
221 | self.bot3 = DoubleConv(512, 256)
222 |
223 | self.up1 = Up(512, 128)
224 | self.sa4 = SelfAttention(128, 16)
225 | self.up2 = Up(256, 64)
226 | self.sa5 = SelfAttention(64, 32)
227 | self.up3 = Up(128, 64)
228 | self.sa6 = SelfAttention(64, 64)
229 | self.outc = nn.Conv2d(64, c_out, kernel_size=1)
230 |
231 | if num_classes is not None:
232 | self.label_emb = nn.Embedding(num_classes, time_dim)
233 |
234 | def pos_encoding(self, t, channels):
235 | inv_freq = 1.0 / (
236 | 10000
237 | ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
238 | )
239 | pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
240 | pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
241 | pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
242 | return pos_enc
243 |
244 | def forward(self, x, t, y):
245 | t = t.unsqueeze(-1).type(torch.float)
246 | t = self.pos_encoding(t, self.time_dim)
247 |
248 | if y is not None:
249 | t += self.label_emb(y)
250 |
251 | x1 = self.inc(x)
252 | x2 = self.down1(x1, t)
253 | x2 = self.sa1(x2)
254 | x3 = self.down2(x2, t)
255 | x3 = self.sa2(x3)
256 | x4 = self.down3(x3, t)
257 | x4 = self.sa3(x4)
258 |
259 | x4 = self.bot1(x4)
260 | x4 = self.bot2(x4)
261 | x4 = self.bot3(x4)
262 |
263 | x = self.up1(x4, x3, t)
264 | x = self.sa4(x)
265 | x = self.up2(x, x2, t)
266 | x = self.sa5(x)
267 | x = self.up3(x, x1, t)
268 | x = self.sa6(x)
269 | output = self.outc(x)
270 | return output
271 |
--------------------------------------------------------------------------------
/src/ddpm_ddim/noise.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/diffusion-pytorch/3c00400219e799ffe6554c3128743cc029f09507/src/ddpm_ddim/noise.jpg
--------------------------------------------------------------------------------
/src/ddpm_ddim/noise_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.utils import save_image
3 | from main import Diffusion
4 | from utils import get_data
5 | import argparse
6 |
7 | parser = argparse.ArgumentParser()
8 | args = parser.parse_args()
9 | args.batch_size = 1 # 5
10 | args.image_size = 64
11 | args.dataset_path = '../dataset/cat'
12 |
13 | dataloader = get_data(args)
14 |
15 | diff = Diffusion(device="cpu")
16 |
17 | image = next(iter(dataloader))[0]
18 | t = torch.Tensor([50, 100, 150, 200, 300, 600, 700, 999]).long()
19 |
20 | noised_image, _ = diff.noise_images(image, t)
21 | save_image(noised_image.add(1).mul(0.5), "noise.jpg")
--------------------------------------------------------------------------------
/src/ddpm_ddim/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | from PIL import Image
5 | from matplotlib import pyplot as plt
6 | from torch.utils.data import DataLoader
7 | from torchvision import transforms
8 | from torch.utils.data import Dataset
9 | from glob import glob
10 |
11 | def plot_images(images):
12 | plt.figure(figsize=(32, 32))
13 | plt.imshow(torch.cat([
14 | torch.cat([i for i in images.cpu()], dim=-1),
15 | ], dim=-2).permute(1, 2, 0).cpu())
16 | plt.show()
17 |
18 |
19 | def save_images(images, path):
20 | torchvision.utils.save_image(images, path,
21 | nrow=4,
22 |                                  normalize=True, value_range=(-1, 1))  # 'range' was renamed to 'value_range' in newer torchvision
23 |
24 |
25 | def get_data(args):
26 | dataset = ImageDataset(img_size=args.image_size, dataset_path=args.dataset_path)
27 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
28 | return dataloader
29 |
30 | class ImageDataset(Dataset):
31 | def __init__(self, img_size, dataset_path):
32 | self.train_images = self.listdir(dataset_path)
33 |
34 | transform_list = [
35 | transforms.Resize(size=[img_size, img_size]),
36 | transforms.RandomHorizontalFlip(p=0.5),
37 | transforms.ToTensor(), # [0, 255] -> [0, 1]
38 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), # [0, 1] -> [-1, 1]
39 | ]
40 |
41 | self.transform = transforms.Compose(transform_list)
42 |
43 | def listdir(self, dir_path):
44 | extensions = ['png', 'jpg', 'jpeg', 'JPG']
45 | file_path = []
46 | for ext in extensions:
47 | file_path += glob(os.path.join(dir_path, '*.' + ext))
48 | file_path.sort()
49 | return file_path
50 |
51 | def __getitem__(self, index):
52 | sample_path = self.train_images[index]
53 | img = Image.open(sample_path).convert('RGB')
54 | img = self.transform(img)
55 |
56 |
57 | return img
58 |
59 | def __len__(self):
60 | return len(self.train_images)
61 |
62 | def setup_logging(run_name):
63 | os.makedirs("models", exist_ok=True)
64 | os.makedirs("results", exist_ok=True)
65 | os.makedirs(os.path.join("models", run_name), exist_ok=True)
66 | os.makedirs(os.path.join("results", run_name), exist_ok=True)
--------------------------------------------------------------------------------
/src/evaluation/clipscore.py:
--------------------------------------------------------------------------------
1 | from torchmetrics.multimodal import CLIPScore
2 | from transformers import CLIPTokenizer, CLIPTextModel, CLIPVisionModel, CLIPModel
3 | import torch
4 | from torchvision import transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 |
8 | def clip_image_process(x):
9 | def denormalize(x):
10 |         # [-1, 1] -> [0, 255]
11 | x = ((x + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
12 |
13 | return x
14 |
15 | def resize(x):
16 | x = transforms.Resize(size=[224, 224], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)(x)
17 | return x
18 |
19 | def zero_to_one(x):
20 | x = x.float() / 255.0
21 | return x
22 |
23 | def norm_mean_std(x):
24 | mean = [0.48145466, 0.4578275, 0.40821073]
25 | std = [0.26862954, 0.26130258, 0.27577711]
26 | x = transforms.Normalize(mean=mean, std=std, inplace=True)(x)
27 | return x
28 |
29 |     # if x is in [-1, 1], denormalize it first.
30 | # x = denormalize(x)
31 | x = resize(x)
32 | x = zero_to_one(x)
33 | x = norm_mean_std(x)
34 |
35 | return x
36 |
37 | def contrastive_loss(logits, dim) :
38 | neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
39 | return -neg_ce.mean()
40 |
41 | def clip_contra_loss(img_features, txt_features, logit_scale):
42 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
43 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
44 |
45 | # cosine similarity as logits
46 | logit_scale = logit_scale.exp()
47 | similarity = torch.matmul(txt_features, img_features.t()) * logit_scale
48 |
49 | caption_loss = contrastive_loss(similarity, dim=0)
50 | image_loss = contrastive_loss(similarity, dim=1)
51 |
52 | return (caption_loss + image_loss) / 2.0 # minimize
53 |
54 | def clip_score(img_features, txt_features):
55 | img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
56 | txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
57 |
58 | # score = 100 * (img_features * txt_features).sum(axis=-1)
59 | # score = torch.mean(score)
60 |
61 |     # Equivalent to the commented-out computation above.
62 | score = F.cosine_similarity(img_features, txt_features).mean()
63 | return score
64 |
65 | # library
66 | image = torch.randint(255, (2, 3, 224, 224))
67 | text = ["a photo of a cat", "a photo of a cat"]
68 | version = 'openai/clip-vit-large-patch14'
69 | metric = CLIPScore(model_name_or_path=version)
70 | score = metric(image, text)
71 | print(score)
72 |
73 | """
74 | Step 1. Model Init
75 | """
76 | tokenizer = CLIPTokenizer.from_pretrained(version)
77 | clip_text_encoder = CLIPTextModel.from_pretrained(version)
78 | clip_image_encoder = CLIPVisionModel.from_pretrained(version)
79 | clip_model = CLIPModel.from_pretrained(version)
80 |
81 | """
82 | Step 2. Text
83 | """
84 | batch_encoding = tokenizer(text, truncation=True, max_length=77, padding="max_length", return_tensors="pt")
85 | # [input_ids, attention_mask] -> both have shape [bs, 77].
86 | # input_ids is the tokenized text; attention_mask marks which tokens are valid (1 = valid, 0 = padding).
87 |
88 | text_token = batch_encoding["input_ids"]
89 | t_embed = clip_text_encoder(text_token)  # same as clip_model.text_model(text_token)
90 | # [last_hidden_state, pooler_output] -> [bs, 77, 768], [bs, 768]
91 | # last_hidden_state = word embedding
92 | # pooler_output = sentence embedding
93 |
94 | text_feature = clip_model.get_text_features(text_token)
95 | # pooler_output (sentence embedding) passed through a Linear projection (text_projection)
96 | # [bs, 768]
97 |
98 | """
99 | Step 3. Image
100 | """
101 | image = clip_image_process(image)
102 |
103 | i_embed = clip_image_encoder(image)  # same as clip_model.vision_model(image)
104 | # [last_hidden_state, pooler_output] -> [bs, 257, 1024], [bs, 1024]  (256 patches + 1 CLS token for ViT-L/14 at 224px)
105 |
106 | image_feature = clip_model.get_image_features(image)
107 | # pooler_output passed through a Linear projection (visual_projection)
108 |
109 | print(clip_score(image_feature, text_feature))
110 | print(clip_contra_loss(image_feature, text_feature, clip_model.logit_scale))
111 |
112 |
--------------------------------------------------------------------------------
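For real evaluation, the random tensors above would be replaced by generated samples. A minimal sketch that reuses tokenizer, clip_model, clip_image_process and clip_score defined in this file (the image path and prompt are placeholders):

    from PIL import Image
    from torchvision import transforms

    pil = Image.open('generated/astronaut.png').convert('RGB')
    x = transforms.ToTensor()(pil).unsqueeze(0) * 255.0    # [1, 3, H, W], values in [0, 255]

    with torch.no_grad():
        img_feat = clip_model.get_image_features(clip_image_process(x))
        txt = tokenizer(["a photograph of an astronaut riding a horse"],
                        truncation=True, max_length=77, padding="max_length", return_tensors="pt")
        txt_feat = clip_model.get_text_features(txt["input_ids"])
    print(clip_score(img_feat, txt_feat))                   # cosine similarity in [-1, 1]; higher is better
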
/src/evaluation/data_loader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from itertools import chain
3 | from PIL import Image
4 | from torch.utils import data
5 | from torchvision import transforms
6 |
7 |
8 | def listdir(dname):
9 | fnames = list(chain(*[list(Path(dname).rglob('*.' + ext))
10 | for ext in ['png', 'jpg', 'jpeg', 'JPG']]))
11 | return fnames
12 |
13 |
14 | class DefaultDataset(data.Dataset):
15 | def __init__(self, root, transform=None):
16 | self.samples = listdir(root)
17 | self.samples.sort()
18 | self.transform = transform
19 | self.targets = None
20 |
21 | def __getitem__(self, index):
22 | fname = self.samples[index]
23 | img = Image.open(fname).convert('RGB')
24 | if self.transform is not None:
25 | img = self.transform(img)
26 | return img
27 |
28 | def __len__(self):
29 | return len(self.samples)
30 |
31 |
32 |
33 | def get_eval_loader(root, img_size=256, batch_size=32,
34 | imagenet_normalize=True, shuffle=True,
35 | num_workers=4, drop_last=False):
36 | print('Preparing DataLoader for the evaluation phase...')
37 | if imagenet_normalize:
38 | height, width = 299, 299
39 | mean = [0.485, 0.456, 0.406]
40 | std = [0.229, 0.224, 0.225]
41 | else:
42 | height, width = img_size, img_size
43 | mean = [0.5, 0.5, 0.5]
44 | std = [0.5, 0.5, 0.5]
45 |
46 | transform = transforms.Compose([
47 | transforms.Resize([img_size, img_size]),
48 | transforms.Resize([height, width]),
49 | transforms.ToTensor(),
50 | transforms.Normalize(mean=mean, std=std)
51 | ])
52 |
53 | dataset = DefaultDataset(root, transform=transform)
54 | return data.DataLoader(dataset=dataset,
55 | batch_size=batch_size,
56 | shuffle=shuffle,
57 | num_workers=num_workers,
58 | pin_memory=True,
59 | drop_last=drop_last)
60 |
--------------------------------------------------------------------------------
/src/evaluation/fid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from torchvision import models
7 | from scipy import linalg
8 | from data_loader import get_eval_loader
9 | from tqdm import tqdm
10 |
11 |
12 | class InceptionV3(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |         inception = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1)  # 'pretrained=True' is deprecated
16 | self.block1 = nn.Sequential(
17 | inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
18 | inception.Conv2d_2b_3x3,
19 | nn.MaxPool2d(kernel_size=3, stride=2))
20 | self.block2 = nn.Sequential(
21 | inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
22 | nn.MaxPool2d(kernel_size=3, stride=2))
23 | self.block3 = nn.Sequential(
24 | inception.Mixed_5b, inception.Mixed_5c,
25 | inception.Mixed_5d, inception.Mixed_6a,
26 | inception.Mixed_6b, inception.Mixed_6c,
27 | inception.Mixed_6d, inception.Mixed_6e)
28 | self.block4 = nn.Sequential(
29 | inception.Mixed_7a, inception.Mixed_7b,
30 | inception.Mixed_7c,
31 | nn.AdaptiveAvgPool2d(output_size=(1, 1)))
32 |
33 | def forward(self, x):
34 | x = self.block1(x)
35 | x = self.block2(x)
36 | x = self.block3(x)
37 | x = self.block4(x)
38 | return x.view(x.size(0), -1)
39 |
40 |
41 | def frechet_distance(mu, cov, mu2, cov2):
42 | cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
43 |     dist = np.sum((mu - mu2)**2) + np.trace(cov + cov2 - 2*cc)
44 | return np.real(dist)
45 |
46 |
47 | @torch.no_grad()
48 | def calculate_fid_given_paths(paths, img_size=256, batch_size=50):
49 | print('Calculating FID given paths %s and %s...' % (paths[0], paths[1]))
50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51 | inception = InceptionV3().eval().to(device)
52 | loaders = [get_eval_loader(path, img_size, batch_size) for path in paths]
53 |
54 | mu, cov = [], []
55 | for loader in loaders:
56 | actvs = []
57 | for x in tqdm(loader, total=len(loader)):
58 | actv = inception(x.to(device))
59 | actvs.append(actv)
60 | actvs = torch.cat(actvs, dim=0).cpu().detach().numpy()
61 | mu.append(np.mean(actvs, axis=0))
62 | cov.append(np.cov(actvs, rowvar=False))
63 | fid_value = frechet_distance(mu[0], cov[0], mu[1], cov[1])
64 | return fid_value
65 |
66 |
67 | if __name__ == '__main__':
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('--paths', type=str, nargs=2, help='paths to real and fake images')
70 | parser.add_argument('--img_size', type=int, default=256, help='image resolution')
71 | parser.add_argument('--batch_size', type=int, default=64, help='batch size to use')
72 | args = parser.parse_args()
73 | fid_value = calculate_fid_given_paths(args.paths, args.img_size, args.batch_size)
74 | print('FID: ', fid_value)
75 |
76 | # python fid.py --paths PATH_REAL PATH_FAKE
--------------------------------------------------------------------------------
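frechet_distance() above implements FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2*(Sigma_r Sigma_g)^(1/2)) over Inception pool features. Besides the CLI shown in the comment, the function can be called directly from another script; the directory names below are placeholders:

    from fid import calculate_fid_given_paths   # assuming this module is importable as fid

    fid_value = calculate_fid_given_paths(['./real_images', './generated_images'],
                                          img_size=256, batch_size=50)
    print('FID:', fid_value)   # lower is better
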
/src/stable_diffusion/sd_main.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | from transformers import CLIPTextModel, CLIPTokenizer
4 | from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
5 | from tqdm.auto import tqdm
6 |
7 | # pip install diffusers
8 |
9 | model_ckpt = "CompVis/stable-diffusion-v1-4"
10 | torch_device = "cpu"
11 |
12 | # init
13 | vae = AutoencoderKL.from_pretrained(model_ckpt, subfolder="vae")
14 | tokenizer = CLIPTokenizer.from_pretrained(model_ckpt, subfolder="tokenizer")
15 | text_encoder = CLIPTextModel.from_pretrained(model_ckpt, subfolder="text_encoder")
16 | unet = UNet2DConditionModel.from_pretrained(model_ckpt, subfolder="unet")
17 | scheduler = PNDMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
18 |
19 | # device
20 | vae.to(torch_device)
21 | text_encoder.to(torch_device)
22 | unet.to(torch_device)
23 |
24 | # parameter
25 | prompt = ["a photograph of an astronaut riding a horse"]
26 | height = 512 # default height of Stable Diffusion
27 | width = 512 # default width of Stable Diffusion
28 | num_inference_steps = 25 # Number of denoising steps
29 | guidance_scale = 7.5 # Scale for classifier-free guidance
30 | generator = torch.manual_seed(0)  # Seed generator to create the initial latent noise
31 | batch_size = len(prompt)
32 | scheduler.set_timesteps(num_inference_steps)
33 | print(scheduler.timesteps)
34 |
35 | """
36 | Step 1.
37 | Make text embeddings
38 | """
39 | text_input = tokenizer(
40 | prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
41 | )
42 |
43 | with torch.no_grad():
44 | text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
45 |
46 | uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
47 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
48 |
49 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
50 |
51 | """
52 | Step 2.
53 | Reverse process
54 | """
55 | # Create random noise
56 | latents = torch.randn(
57 | (batch_size, unet.config.in_channels, height // 8, width // 8),
58 | generator=generator,
59 | )
60 | latents = latents.to(torch_device)
61 | latents = latents * scheduler.init_noise_sigma  # for the PNDM scheduler this is 1.0
62 |
63 | for t in tqdm(scheduler.timesteps):
64 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
65 | latent_model_input = torch.cat([latents] * 2)
66 |
67 | latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
68 |
69 | # predict the noise residual
70 | with torch.no_grad():
71 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
72 |
73 | # perform guidance
74 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
75 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
76 |
77 | # compute the previous noisy sample x_t -> x_t-1
78 | latents = scheduler.step(noise_pred, t, latents).prev_sample
79 |
80 |
81 | """
82 | Step 3.
83 | Image decoding
84 | """
85 | latents = 1 / 0.18215 * latents
86 | with torch.no_grad():
87 | image = vae.decode(latents).sample
88 |
89 | image = (image / 2 + 0.5).clamp(0, 1)
90 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
91 | images = (image * 255).round().astype("uint8")
92 | pil_images = [Image.fromarray(image) for image in images]
93 | pil_images[0].save("main_results.png")
--------------------------------------------------------------------------------
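Because the denoising loop only talks to the scheduler through set_timesteps(), scale_model_input() and step(), other diffusers schedulers can be dropped in without touching Step 2 or Step 3. A minimal sketch swapping PNDM for DDIM on the same checkpoint:

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
    scheduler.set_timesteps(num_inference_steps)
    # ...then rerun the loop in Step 2 and the decoding in Step 3 unchanged.
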
/src/stable_diffusion/sd_simple_main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import StableDiffusionPipeline
3 |
4 | # pip install diffusers
5 |
6 | model_ckpt = "CompVis/stable-diffusion-v1-4"
7 | device = "mps" # cuda, cpu, mps
8 | weight_dtype = torch.float16  # use torch.float32 when running on cpu
9 |
10 | pipe = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=weight_dtype)
11 | pipe = pipe.to(device)
12 |
13 | prompt = "a photograph of an astronaut riding a horse"
14 | image = pipe(prompt).images[0]
15 | image.save("simple_results.png")
--------------------------------------------------------------------------------
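The high-level pipeline exposes the same sampling knobs that sd_main.py wires up by hand; a minimal sketch with explicit values (the negative prompt is just an example):

    generator = torch.manual_seed(0)
    image = pipe(prompt,
                 num_inference_steps=25,
                 guidance_scale=7.5,
                 negative_prompt="low quality, blurry",
                 generator=generator).images[0]
    image.save("simple_results_seeded.png")
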