├── .gitignore ├── FGD.py ├── README.md ├── cbilateral.py ├── demo.ipynb ├── diffusionModel.py ├── figures ├── FGDTeaser.jpg ├── FGDThumbnailYTLarge.jpg └── teaser_updated.png ├── imgs └── red_hat.jpg └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | */.ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to 
pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | 163 | #pytorch 164 | *.ckpt 165 | 166 | # misc 167 | .ipynb_checkpoints 168 | __pycache__ 169 | FGD_release -------------------------------------------------------------------------------- /FGD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cbilateral import getCrossBilateralMatrix4D 3 | import json 4 | import numpy as np 5 | 6 | device = "cuda" if torch.cuda.is_available() else "cpu" 7 | class FGD(): 8 | def __init__(self, diffusionModel, guide_image, detail=1.2, sigmas=[3,3,0.3], t_end=15, norm_steps=0): 9 | self.guide_image = guide_image 10 | self.detail = detail 11 | self.t_end = t_end 12 | self.sigmas = sigmas 13 | self.norm_steps = norm_steps 14 | self.model = diffusionModel 15 | self.bilateral_matrix_4d = None 16 | 17 | self.guide_latent = None 18 | self.guide_structure = None 19 | self.guide_structure_normalized = None 20 | 21 | self.init_guide_latent = None 22 | self.init_guide_structure = None 23 | self.init_guide_stucture_normalized=None 24 | 25 | self.init_bilateral_matrix_4d=None 26 | self.set_guide_image(guide_image) 27 | 28 | def set_ST(self, detail=1.6, recompute_matrix=True, sigmas=[3,3,0.3]): 29 | if recompute_matrix: 30 | self.set_bilateral_matrix(sigmas) 31 | self.detail = detail 32 | self.t_end = 15 33 | self.norm_steps = 50 34 | 35 | 36 | def reset(self): 37 | self.init_guide_latent = self.guide_latent 38 | self.guide_structure = self.init_guide_structure 39 | self.guide_structure_normalized = self.init_guide_structure_normalized 40 | self.bilateral_matrix_4d = self.init_bilateral_matrix_4d 41 | 42 | def set_guide_image(self, guide_image): 43 | self.guide_latent = self.model.encode_image(guide_image) 44 | self.guide_image = guide_image 45 | if self.sigmas != None: 46 | self.set_bilateral_matrix(self.sigmas) 47 | 48 | def set_bilateral_matrix(self,sigmas): 49 | assert len(sigmas)==2 or len(sigmas)==3, "sigmas has invalid number of entries (either 2 or 3)" 50 | 
sigmas = np.array(sigmas).astype(np.double) 51 | if len(sigmas) == 2: 52 | sigmas = np.insert(sigmas, 1, sigmas[0]) 53 | 54 | guide_latent_processed = self.guide_latent.detach().cpu().permute(0, 2, 3, 1).numpy() 55 | guide_latent_processed = np.squeeze(guide_latent_processed) 56 | bilateral_matrix = getCrossBilateralMatrix4D(guide_latent_processed.astype('double'),sigmas) 57 | self.bilateral_matrix_4d = torch.Tensor(bilateral_matrix).unsqueeze(0).repeat((4,1,1)).to(device) 58 | guide_structure_latent = torch.matmul(self.bilateral_matrix_4d, self.guide_latent.reshape(4,4096,1)) 59 | guide_structure_latent = guide_structure_latent.reshape(1,4,64,64) 60 | 61 | guide_mean = torch.mean(guide_structure_latent, (2,3), keepdim=True) 62 | guide_std = torch.std(guide_structure_latent, (2,3), keepdim=True) 63 | 64 | self.guide_structure_normalized = (guide_structure_latent - guide_mean) / guide_std 65 | self.guide_structure = guide_structure_latent 66 | 67 | self.init_guide_structure = self.guide_structure 68 | self.init_guide_structure_normalized=self.guide_structure_normalized 69 | self.init_bilateral_matrix_4d = self.bilateral_matrix_4d 70 | 71 | self.sigmas = sigmas.tolist() 72 | 73 | def get_residual_structure(self, latents): 74 | current_structure = torch.matmul(self.bilateral_matrix_4d, latents.reshape(4,4096,1)) 75 | current_structure = current_structure.reshape(1,4,64,64) 76 | 77 | d_structure = self.guide_structure - current_structure 78 | return d_structure 79 | 80 | def get_structure(self, latents, bm_4d=None): 81 | if bm_4d ==None: 82 | bm_4d = self.bilateral_matrix_4d 83 | structure = torch.matmul(bm_4d, latents.reshape(4,4096,1)) 84 | structure = structure.reshape(1,4,64,64) 85 | return structure 86 | 87 | def get_guidance(self, latents, input_latents, scheduler, t): 88 | guide_low = self.guide_structure 89 | 90 | st_low = self.get_structure(latents) 91 | st_high = latents - st_low 92 | 93 | weight= self.detail 94 | 95 | d = guide_low - st_low 96 | 97 | return 
weight, d 98 | 99 | 100 | def get_guidance_normalized(self, latents, input_latents, scheduler, t): 101 | current_structure = self.get_structure(latents) 102 | guide_structure = self.guide_structure 103 | 104 | current_mean = torch.mean(current_structure, (2,3), keepdim=True) 105 | current_std = torch.std(current_structure, (2,3), keepdim=True) 106 | 107 | guide_structure_renormalized = self.guide_structure_normalized * current_std + current_mean 108 | d_structure_renormalized = guide_structure_renormalized - current_structure 109 | 110 | residual_score = torch.mean(torch.abs(d_structure_renormalized)) 111 | 112 | weight = self.detail 113 | 114 | return weight, d_structure_renormalized 115 | 116 | def get_params(self): 117 | params = { 118 | 'guide image':self.guide_image, 119 | 'detail':self.detail, 120 | 'sigmas':self.sigmas, 121 | 't_end':self.t_end, 122 | 'norm steps':self.norm_steps, 123 | } 124 | return params 125 | def __str__(self): 126 | return (json.dumps(self.get_params(), indent=2)) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Filtered-Guided Diffusion for Controllable Image Generation 2 | *Fast, lightweight, architecture-independent, low-level frequency control for diffusion-based Image-to-Image translations.* 3 | 4 | [**Filtered-Guided Diffusion for Controllable Image Generation**](http://filterguideddiffusion.github.io/)
5 | [Zeqi Gu*](https://github.com/jaclyngu), 6 | [Ethan Yang*](https://www.cs.cornell.edu/abe/group/members), 7 | [Abe Davis](http://abedavis.com/)
8 | \* denotes equal contribution
9 | _[GitHub](https://github.com/jaclyngu/FilteredGuidedDiffusion) | [Paper](https://dl.acm.org/doi/10.1145/3641519.3657489) | [Project Page](http://filterguideddiffusion.github.io)_ 10 | 11 | 12 | ![Teaser](./figures/teaser_updated.png) 13 | 14 | 17 | 18 | ## Code (now released!) 19 | ### Summary 20 | We provide a lightweight implementation of FGD that contains all the core functionality described in our paper. Our code is built on the [🤗 diffusers library](https://huggingface.co/docs/diffusers/en/index) and uses [taichi lang](https://www.taichi-lang.org/) for efficient computation of the cross bilateral matrix. 21 | 22 | A full explanation of how to use our code is given in the Jupyter notebook `demo.ipynb`. 23 | 24 | For questions regarding the code, or for access to a more comprehensive (although much less user-friendly) set of functions including experimental features, debugging, and evaluation, please contact both first authors at zg45@cornell.edu and eey8@cornell.edu. 25 | 26 | ### Setup 27 | We provide a requirements.txt file listing the package versions our implementation of FGD was tested with. Note that running our code requires a GPU. 28 | ``` 29 | diffusers==0.30.0 30 | numpy==2.0.1 31 | Pillow==10.4.0 32 | pytorch_lightning==2.4.0 33 | taichi==1.7.1 34 | torch==2.4.0+cu118 35 | tqdm==4.66.5 36 | transformers==4.44.0 37 | ``` 38 | **Note:** Jupyter Notebook is also required to run our demo, as we do not provide a command-line interface.
39 | 40 | ## Citation 41 | For those wishing to use our work, please use the following citation: 42 | ``` 43 | @inproceedings{gu2024filter, 44 | title={Filter-Guided Diffusion for Controllable Image Generation}, 45 | author={Gu, Zeqi and Yang, Ethan and Davis, Abe}, 46 | booktitle={ACM SIGGRAPH 2024 Conference Papers}, 47 | pages={1--10}, 48 | year={2024} 49 | } 50 | ``` -------------------------------------------------------------------------------- /cbilateral.py: -------------------------------------------------------------------------------- 1 | import taichi as ti 2 | import numpy as np 3 | 4 | ti.init(arch=ti.gpu, default_fp=ti.f64) 5 | EPSILON = 1e-5 6 | 7 | @ti.func 8 | def nGaussExp(x: ti.f64, sigma: ti.f64): 9 | # Special case if sigma is 0. In this case we should have a dirac delta. 10 | value = 0.0 11 | if(sigma filter.t_end: 177 | st_filtered += d*weight 178 | else: 179 | st_filtered = st 180 | 181 | assert self.scheduler_type == 'ddpm', "released FGD implementation only supports DDPM" 182 | 183 | prev_t = self.scheduler.previous_timestep(t) 184 | alpha_prod_t = self.scheduler.alphas_cumprod[t] 185 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else self.scheduler.one 186 | beta_prod_t = 1 - alpha_prod_t 187 | beta_prod_t_prev = 1 - alpha_prod_t_prev 188 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev 189 | current_beta_t = 1 - current_alpha_t 190 | 191 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t 192 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t 193 | 194 | pred_prev_sample = pred_original_sample_coeff * st_filtered + current_sample_coeff * xt 195 | variance_noise = randn_tensor( 196 | noise_pred.shape, device=self.device, dtype=noise_pred.dtype 197 | ) 198 | 199 | variance = (self.scheduler._get_variance(t) ** 0.5) * variance_noise 200 | xt_filtered = pred_prev_sample + variance 201 | 202 | return xt_filtered 203 | 204 | def 
decode_latents(self, latents): 205 | with torch.no_grad(): 206 | image = self.vae.decode(latents*1/self.vae.config.scaling_factor).sample 207 | 208 | image = (image / 2 + 0.5).clamp(0, 1) 209 | image = image.detach().cpu().permute(0, 2, 3, 1).numpy() 210 | images = (image * 255).round().astype("uint8") 211 | pil_images = [Image.fromarray(image) for image in images] 212 | if latents.shape[0] == 1: 213 | return pil_images[0] 214 | else: 215 | return pil_images 216 | 217 | def encode_image(self, image): 218 | if isinstance(image, str): 219 | guide_image = load_image(image) 220 | print() 221 | else: 222 | guide_image = image 223 | processor = VaeImageProcessor(self.vae.config) 224 | guide_processed = processor.preprocess(guide_image,self.height,self.width).to(self.device) 225 | with torch.no_grad(): 226 | guide_latent = self.vae.encode(guide_processed).latent_dist.sample()*self.vae.config.scaling_factor 227 | return guide_latent 228 | 229 | def get_params(self): 230 | params = { 231 | 'prompt':self.prompt, 232 | 'version':self.version, 233 | 'use_ema':self.use_ema, 234 | 'scheduler':self.scheduler_type, 235 | 'initialization':self.latent_initialization, 236 | } 237 | return params 238 | def __str__(self): 239 | return (json.dumps(self.get_params(), indent=2)) -------------------------------------------------------------------------------- /figures/FGDTeaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/figures/FGDTeaser.jpg -------------------------------------------------------------------------------- /figures/FGDThumbnailYTLarge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/figures/FGDThumbnailYTLarge.jpg 
-------------------------------------------------------------------------------- /figures/teaser_updated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/figures/teaser_updated.png -------------------------------------------------------------------------------- /imgs/red_hat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FilteredGuidedDiffusion/505b260b5de437779a436d4ad3c0670a5bc806c3/imgs/red_hat.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.30.0 2 | numpy==2.0.1 3 | Pillow==10.4.0 4 | pytorch_lightning==2.4.0 5 | taichi==1.7.1 6 | torch==2.4.0+cu118 7 | tqdm==4.66.5 8 | transformers==4.44.0 9 | --------------------------------------------------------------------------------