├── .gitignore
├── FGD.py
├── README.md
├── cbilateral.py
├── demo.ipynb
├── diffusionModel.py
├── figures
├── FGDTeaser.jpg
├── FGDThumbnailYTLarge.jpg
└── teaser_updated.png
├── imgs
└── red_hat.jpg
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 | */.ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/#use-with-ide
111 | .pdm.toml
112 |
113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114 | __pypackages__/
115 |
116 | # Celery stuff
117 | celerybeat-schedule
118 | celerybeat.pid
119 |
120 | # SageMath parsed files
121 | *.sage.py
122 |
123 | # Environments
124 | .env
125 | .venv
126 | env/
127 | venv/
128 | ENV/
129 | env.bak/
130 | venv.bak/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
163 | #pytorch
164 | *.ckpt
165 |
166 | # misc
167 | .ipynb_checkpoints
168 | __pycache__
169 | FGD_release
--------------------------------------------------------------------------------
/FGD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from cbilateral import getCrossBilateralMatrix4D
3 | import json
4 | import numpy as np
5 |
6 | device = "cuda" if torch.cuda.is_available() else "cpu"
class FGD():
    """Guidance state for filter-guided diffusion.

    Holds a guide image, its latent (obtained from ``diffusionModel.encode_image``),
    the cross bilateral matrix built from that latent, and the filtered
    ("structure") version of the guide latent.  The ``get_guidance*`` methods
    return the difference between the guide structure and the structure of
    the latents currently being denoised.

    The filtering code reshapes latents to ``(4, 4096, 1)`` and back to
    ``(1, 4, 64, 64)``, so only a single 4-channel 64x64 latent is supported.
    """

    def __init__(self, diffusionModel, guide_image, detail=1.2, sigmas=[3,3,0.3], t_end=15, norm_steps=0):
        """Store the settings and encode/filter the guide image.

        Args:
            diffusionModel: object exposing ``encode_image(image)``.
            guide_image: forwarded unchanged to ``diffusionModel.encode_image``.
            detail: weight returned alongside the guidance difference.
            sigmas: 2 or 3 filter sigmas, or ``None`` to skip building the
                bilateral matrix.
            t_end: stored as-is; not read inside this class.
            norm_steps: stored as-is; not read inside this class.
        """
        self.guide_image = guide_image
        self.detail = detail
        self.t_end = t_end
        # Copy so the shared default list can never be mutated through self.sigmas.
        self.sigmas = list(sigmas) if sigmas is not None else None
        self.norm_steps = norm_steps
        self.model = diffusionModel
        self.bilateral_matrix_4d = None

        self.guide_latent = None
        self.guide_structure = None
        self.guide_structure_normalized = None

        # Snapshots restored by reset().  The normalized-structure snapshot
        # used to be created under a misspelled name, which made reset() raise
        # AttributeError whenever no bilateral matrix had been computed.
        self.init_guide_latent = None
        self.init_guide_structure = None
        self.init_guide_structure_normalized = None

        self.init_bilateral_matrix_4d = None
        self.set_guide_image(guide_image)

    def set_ST(self, detail=1.6, recompute_matrix=True, sigmas=[3,3,0.3]):
        """Apply a preset: given ``detail``, ``t_end = 15`` and ``norm_steps = 50``.

        ``sigmas`` is only used when ``recompute_matrix`` is true, in which
        case the bilateral matrix is rebuilt from it.
        """
        if recompute_matrix:
            self.set_bilateral_matrix(sigmas)
        self.detail = detail
        self.t_end = 15
        self.norm_steps = 50

    def reset(self):
        """Restore the guide latent, structures and matrix from their snapshots."""
        # The latent snapshot is taken in set_guide_image(); previously this
        # line assigned in the opposite direction and restored nothing.
        self.guide_latent = self.init_guide_latent
        self.guide_structure = self.init_guide_structure
        self.guide_structure_normalized = self.init_guide_structure_normalized
        self.bilateral_matrix_4d = self.init_bilateral_matrix_4d

    def set_guide_image(self, guide_image):
        """Encode ``guide_image`` and, if sigmas are set, rebuild the filter."""
        self.guide_latent = self.model.encode_image(guide_image)
        self.init_guide_latent = self.guide_latent
        self.guide_image = guide_image
        if self.sigmas is not None:
            self.set_bilateral_matrix(self.sigmas)

    def set_bilateral_matrix(self, sigmas):
        """Build the cross bilateral matrix from the guide latent and filter it.

        Updates the matrix, the filtered guide latent, its per-channel
        normalized version, the matching ``init_*`` snapshots and
        ``self.sigmas`` (stored as a 3-element list).

        Args:
            sigmas: sequence of 2 or 3 numbers.  With 2 entries the first is
                duplicated at index 1 to obtain 3.
                NOTE(review): presumably [spatial, spatial, range] — confirm
                against getCrossBilateralMatrix4D.
        """
        assert len(sigmas)==2 or len(sigmas)==3, "sigmas has invalid number of entries (either 2 or 3)"
        sigmas = np.array(sigmas).astype(np.double)
        if len(sigmas) == 2:
            sigmas = np.insert(sigmas, 1, sigmas[0])

        # Channels-last numpy copy of the latent with size-1 axes squeezed out.
        guide_latent_processed = self.guide_latent.detach().cpu().permute(0, 2, 3, 1).numpy()
        guide_latent_processed = np.squeeze(guide_latent_processed)
        bilateral_matrix = getCrossBilateralMatrix4D(guide_latent_processed.astype('double'), sigmas)
        # The same matrix is repeated once per latent channel (4 of them).
        self.bilateral_matrix_4d = torch.Tensor(bilateral_matrix).unsqueeze(0).repeat((4,1,1)).to(device)
        guide_structure_latent = torch.matmul(self.bilateral_matrix_4d, self.guide_latent.reshape(4,4096,1))
        guide_structure_latent = guide_structure_latent.reshape(1,4,64,64)

        # Per-channel statistics over the two spatial axes.
        guide_mean = torch.mean(guide_structure_latent, (2,3), keepdim=True)
        guide_std = torch.std(guide_structure_latent, (2,3), keepdim=True)

        self.guide_structure_normalized = (guide_structure_latent - guide_mean) / guide_std
        self.guide_structure = guide_structure_latent

        self.init_guide_structure = self.guide_structure
        self.init_guide_structure_normalized = self.guide_structure_normalized
        self.init_bilateral_matrix_4d = self.bilateral_matrix_4d

        self.sigmas = sigmas.tolist()

    def get_residual_structure(self, latents):
        """Return ``guide_structure`` minus the filtered ``latents``."""
        return self.guide_structure - self.get_structure(latents)

    def get_structure(self, latents, bm_4d=None):
        """Filter ``latents`` with ``bm_4d`` (default: the stored matrix).

        Returns a tensor reshaped to ``(1, 4, 64, 64)``.
        """
        # Identity check: ``== None`` on a tensor relies on comparison fallbacks.
        if bm_4d is None:
            bm_4d = self.bilateral_matrix_4d
        structure = torch.matmul(bm_4d, latents.reshape(4,4096,1))
        return structure.reshape(1,4,64,64)

    def get_guidance(self, latents, input_latents, scheduler, t):
        """Return ``(detail, guide_structure - structure(latents))``.

        ``input_latents``, ``scheduler`` and ``t`` are accepted for interface
        compatibility and are not used.
        """
        st_low = self.get_structure(latents)
        return self.detail, self.guide_structure - st_low

    def get_guidance_normalized(self, latents, input_latents, scheduler, t):
        """Like ``get_guidance`` but with the guide matched to current statistics.

        The normalized guide structure is rescaled to the per-channel mean and
        standard deviation of the current structure before differencing.
        ``input_latents``, ``scheduler`` and ``t`` are not used.
        """
        current_structure = self.get_structure(latents)

        current_mean = torch.mean(current_structure, (2,3), keepdim=True)
        current_std = torch.std(current_structure, (2,3), keepdim=True)

        guide_structure_renormalized = self.guide_structure_normalized * current_std + current_mean
        d_structure_renormalized = guide_structure_renormalized - current_structure

        return self.detail, d_structure_renormalized

    def get_params(self):
        """Return the user-facing settings as a dict."""
        params = {
            'guide image': self.guide_image,
            'detail': self.detail,
            'sigmas': self.sigmas,
            't_end': self.t_end,
            'norm steps': self.norm_steps,
        }
        return params

    def __str__(self):
        """Return ``get_params()`` as JSON indented by two spaces.

        Values JSON cannot encode (e.g. a guide image passed as an image
        object rather than a path) are rendered with ``repr`` instead of
        raising ``TypeError``.
        """
        return json.dumps(self.get_params(), indent=2, default=repr)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Filtered-Guided Diffusion for Controllable Image Generation
2 | *Fast, lightweight, architecture-independent, low-level frequency control for diffusion-based Image-to-Image translations.*
3 |
4 | [**Filtered-Guided Diffusion for Controllable Image Generation**](http://filterguideddiffusion.github.io/)
5 | [Zeqi Gu*](https://github.com/jaclyngu),
6 | [Ethan Yang*](https://www.cs.cornell.edu/abe/group/members),
7 | [Abe Davis](http://abedavis.com/)
8 | \* denotes equal contribution
9 | _[GitHub](https://github.com/jaclyngu/FilteredGuidedDiffusion) | [Paper](https://dl.acm.org/doi/10.1145/3641519.3657489) | [Project Page](http://filterguideddiffusion.github.io)_
10 |
11 |
12 | 
13 |
14 |
17 |
18 | ## Code (now released!)
19 | ### Summary
20 | We provide a lightweight implementation of FGD which contains all the core functionality described in our paper. Our code is based on the [🤗 diffusers library](https://huggingface.co/docs/diffusers/en/index) and [taichi lang](https://www.taichi-lang.org/) for efficient computation of the cross bilateral matrix.
21 |
22 | A full explanation of how to use our code is described in the jupyter notebook demo.ipynb.
23 |
24 | For questions regarding the code, or access to a more comprehensive set of functions (although much less user-friendly) including experimental features, debugging, and evaluation, please contact both authors at zg45@cornell.edu and eey8@cornell.edu.
25 |
26 | ### Setup
27 | We provide a requirements.txt file which lists the packages our implementation of FGD was tested with. Note that running our code requires a GPU.
28 | ```
29 | diffusers==0.30.0
30 | numpy==2.0.1
31 | Pillow==10.4.0
32 | pytorch_lightning==2.4.0
33 | taichi==1.7.1
34 | torch==2.4.0+cu118
35 | tqdm==4.66.5
36 | transformers==4.44.0
37 | ```
38 | **Note:** jupyter notebook is also required in order to run our demo as we do not provide a command line interface.
39 |
40 | ## Citation
41 | For those wishing to use our work, please use the following citation:
42 | ```
43 | @inproceedings{gu2024filter,
44 | title={Filter-Guided Diffusion for Controllable Image Generation},
45 | author={Gu, Zeqi and Yang, Ethan and Davis, Abe},
46 | booktitle={ACM SIGGRAPH 2024 Conference Papers},
47 | pages={1--10},
48 | year={2024}
49 | }
50 | ```
--------------------------------------------------------------------------------
/cbilateral.py:
--------------------------------------------------------------------------------
1 | import taichi as ti
2 | import numpy as np
3 |
4 | ti.init(arch=ti.gpu, default_fp=ti.f64)
5 | EPSILON = 1e-5
6 |
7 | @ti.func
8 | def nGaussExp(x: ti.f64, sigma: ti.f64):
9 | # Special case if sigma is 0. In this case we should have a dirac delta.
10 | value = 0.0
11 | if(sigma filter.t_end:
177 | st_filtered += d*weight
178 | else:
179 | st_filtered = st
180 |
181 | assert self.scheduler_type == 'ddpm', "released FGD implementation only supports DDPM"
182 |
183 | prev_t = self.scheduler.previous_timestep(t)
184 | alpha_prod_t = self.scheduler.alphas_cumprod[t]
185 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else self.scheduler.one
186 | beta_prod_t = 1 - alpha_prod_t
187 | beta_prod_t_prev = 1 - alpha_prod_t_prev
188 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev
189 | current_beta_t = 1 - current_alpha_t
190 |
191 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
192 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
193 |
194 | pred_prev_sample = pred_original_sample_coeff * st_filtered + current_sample_coeff * xt
195 | variance_noise = randn_tensor(
196 | noise_pred.shape, device=self.device, dtype=noise_pred.dtype
197 | )
198 |
199 | variance = (self.scheduler._get_variance(t) ** 0.5) * variance_noise
200 | xt_filtered = pred_prev_sample + variance
201 |
202 | return xt_filtered
203 |
def decode_latents(self, latents):
    """Decode latents with the VAE and convert the result to PIL image(s).

    The latents are divided by the VAE's configured scaling factor before
    decoding.  The decoded output is mapped via ``x / 2 + 0.5``, clamped to
    [0, 1], moved to channels-last layout and quantized to uint8.

    Returns:
        A single PIL image when ``latents.shape[0] == 1``, otherwise a list
        with one PIL image per batch entry.
    """
    scaling_factor = self.vae.config.scaling_factor
    with torch.no_grad():
        decoded = self.vae.decode(latents * 1 / scaling_factor).sample

    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    pixels = pixels.detach().cpu().permute(0, 2, 3, 1).numpy()
    pixels_uint8 = (pixels * 255).round().astype("uint8")
    pil_images = [Image.fromarray(frame) for frame in pixels_uint8]
    return pil_images[0] if latents.shape[0] == 1 else pil_images
216 |
def encode_image(self, image):
    """Encode an image into a scaled VAE latent.

    Args:
        image: either a string, which is passed to ``load_image``, or an
            already-loaded image that is handed to the preprocessor as-is.

    Returns:
        A sample drawn from the VAE's latent distribution for the image,
        multiplied by the VAE's configured scaling factor.  The sample is
        random, so repeated calls may return different latents.
    """
    guide_image = load_image(image) if isinstance(image, str) else image

    # The first positional parameter of VaeImageProcessor is the ``do_resize``
    # flag; the VAE config used to be passed there and only worked because it
    # is truthy.  Spell the flag out instead.
    processor = VaeImageProcessor(do_resize=True)
    guide_processed = processor.preprocess(guide_image, self.height, self.width).to(self.device)

    with torch.no_grad():
        posterior = self.vae.encode(guide_processed).latent_dist
        guide_latent = posterior.sample() * self.vae.config.scaling_factor
    return guide_latent
228 |
def get_params(self):
    """Return the model's settings as a dict.

    Keys are ``'prompt'``, ``'version'``, ``'use_ema'``, ``'scheduler'``
    (from ``scheduler_type``) and ``'initialization'`` (from
    ``latent_initialization``), in that order.
    """
    # (output key, attribute name) pairs, in output order.
    exported = (
        ('prompt', 'prompt'),
        ('version', 'version'),
        ('use_ema', 'use_ema'),
        ('scheduler', 'scheduler_type'),
        ('initialization', 'latent_initialization'),
    )
    return {key: getattr(self, attr) for key, attr in exported}
def __str__(self):
    """Return the output of ``get_params()`` as JSON indented by two spaces."""
    params = self.get_params()
    rendered = json.dumps(params, indent=2)
    return rendered
--------------------------------------------------------------------------------
/figures/FGDTeaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/figures/FGDTeaser.jpg
--------------------------------------------------------------------------------
/figures/FGDThumbnailYTLarge.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/figures/FGDThumbnailYTLarge.jpg
--------------------------------------------------------------------------------
/figures/teaser_updated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/figures/teaser_updated.png
--------------------------------------------------------------------------------
/imgs/red_hat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/imgs/red_hat.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.30.0
2 | numpy==2.0.1
3 | Pillow==10.4.0
4 | pytorch_lightning==2.4.0
5 | taichi==1.7.1
6 | torch==2.4.0+cu118
7 | tqdm==4.66.5
8 | transformers==4.44.0
9 |
--------------------------------------------------------------------------------