├── lib ├── diffc │ ├── rcc │ │ ├── arithmetic-coding │ │ │ ├── metrics.py │ │ │ ├── .gitignore │ │ │ ├── arithmetic-coding-core │ │ │ │ ├── README.md │ │ │ │ ├── src │ │ │ │ │ ├── lib.rs │ │ │ │ │ ├── bitstore.rs │ │ │ │ │ ├── model.rs │ │ │ │ │ └── model │ │ │ │ │ │ ├── one_shot.rs │ │ │ │ │ │ ├── max_length.rs │ │ │ │ │ │ └── fixed_length.rs │ │ │ │ └── Cargo.toml │ │ │ ├── python-bindings │ │ │ │ ├── zipf_encoding │ │ │ │ │ └── __init__.py │ │ │ │ ├── pyproject.toml │ │ │ │ ├── Cargo.toml │ │ │ │ ├── setup.py │ │ │ │ └── src │ │ │ │ │ ├── lib.rs │ │ │ │ │ └── zipf.rs │ │ │ ├── CONTRIBUTING.md │ │ │ ├── src │ │ │ │ ├── lib.rs │ │ │ │ ├── encoder.rs │ │ │ │ └── decoder.rs │ │ │ ├── Cargo.toml │ │ │ └── README.md │ │ ├── cuda_kernels.cu │ │ ├── pfr.py │ │ ├── chunk_coding.py │ │ └── gaussian_channel_simulator.py │ ├── utils │ │ ├── alpha_beta.py │ │ ├── q.py │ │ └── p.py │ ├── denoise.py │ ├── decode.py │ └── encode.py ├── models │ ├── SD15.py │ ├── latent_noise_prediction_model.py │ ├── SD21.py │ ├── SD.py │ ├── Flux.py │ └── SDXL.py ├── image_utils.py ├── blip.py └── metrics.py ├── data ├── div2k │ ├── 0001.png │ └── 0016.png └── kodak │ ├── 01.png │ ├── 02.png │ ├── 03.png │ ├── 04.png │ ├── 05.png │ ├── 06.png │ ├── 07.png │ ├── 08.png │ ├── 09.png │ ├── 10.png │ ├── 11.png │ ├── 12.png │ ├── 13.png │ ├── 14.png │ ├── 15.png │ ├── 16.png │ ├── 17.png │ ├── 18.png │ ├── 19.png │ ├── 20.png │ ├── 21.png │ ├── 22.png │ ├── 23.png │ └── 24.png ├── figures ├── visual-comparison.png ├── kodak-rd-curves-Qalign.png └── div2k-1024-rd-curves-Qalign.png ├── .gitignore ├── environment.yml ├── LICENSE ├── configs ├── SD-1.5-prompt.yaml ├── SD-1.5-base.yaml ├── SD-2.1-base.yaml ├── SDXL-base.yaml └── Flux-base.yaml ├── readme.md ├── notebooks └── blip-prompts.ipynb ├── decompress.py ├── compress.py └── evaluate.py /lib/diffc/rcc/arithmetic-coding/metrics.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /data/div2k/0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/div2k/0001.png -------------------------------------------------------------------------------- /data/div2k/0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/div2k/0016.png -------------------------------------------------------------------------------- /data/kodak/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/01.png -------------------------------------------------------------------------------- /data/kodak/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/02.png -------------------------------------------------------------------------------- /data/kodak/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/03.png 
-------------------------------------------------------------------------------- /data/kodak/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/04.png -------------------------------------------------------------------------------- /data/kodak/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/05.png -------------------------------------------------------------------------------- /data/kodak/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/06.png -------------------------------------------------------------------------------- /data/kodak/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/07.png -------------------------------------------------------------------------------- /data/kodak/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/08.png -------------------------------------------------------------------------------- /data/kodak/09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/09.png -------------------------------------------------------------------------------- /data/kodak/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/10.png -------------------------------------------------------------------------------- /data/kodak/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/11.png -------------------------------------------------------------------------------- /data/kodak/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/12.png -------------------------------------------------------------------------------- /data/kodak/13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/13.png -------------------------------------------------------------------------------- /data/kodak/14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/14.png -------------------------------------------------------------------------------- /data/kodak/15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/15.png -------------------------------------------------------------------------------- /data/kodak/16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/16.png -------------------------------------------------------------------------------- /data/kodak/17.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/17.png -------------------------------------------------------------------------------- /data/kodak/18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/18.png -------------------------------------------------------------------------------- /data/kodak/19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/19.png -------------------------------------------------------------------------------- /data/kodak/20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/20.png -------------------------------------------------------------------------------- /data/kodak/21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/21.png -------------------------------------------------------------------------------- /data/kodak/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/22.png -------------------------------------------------------------------------------- /data/kodak/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/23.png -------------------------------------------------------------------------------- /data/kodak/24.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/data/kodak/24.png -------------------------------------------------------------------------------- /figures/visual-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/figures/visual-comparison.png -------------------------------------------------------------------------------- /figures/kodak-rd-curves-Qalign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/figures/kodak-rd-curves-Qalign.png -------------------------------------------------------------------------------- /figures/div2k-1024-rd-curves-Qalign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JeremyIV/diffc/HEAD/figures/div2k-1024-rd-curves-Qalign.png -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/arithmetic-coding-core/README.md: -------------------------------------------------------------------------------- 1 | # Arithmetic Coding Core 2 | 3 | core traits for the `arithmetic-coding` crate -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/python-bindings/zipf_encoding/__init__.py: -------------------------------------------------------------------------------- 1 | from .zipf_encoding import encode_zipf, decode_zipf 2 | 3 | __all__ = ["encode_zipf", "decode_zipf"] 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 
env/ 2 | # Python cache 3 | __pycache__/ 4 | *.pyc 5 | *.so 6 | 7 | # Jupyter 8 | .ipynb_checkpoints/ 9 | *.ipynb 10 | 11 | # PyTorch models 12 | *.pt 13 | 14 | # Documentation 15 | .github/ 16 | .rustfmt.toml 17 | 18 | # Results and tests 19 | results/ 20 | test_run.sh -------------------------------------------------------------------------------- /lib/diffc/utils/alpha_beta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_alpha_prod_and_beta_prod(snr): 5 | if snr == torch.inf: 6 | alpha_prod = 1 7 | else: 8 | alpha_prod = snr ** 2 / (1 + snr ** 2) 9 | beta_prod = 1 - alpha_prod 10 | return alpha_prod, beta_prod 11 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/arithmetic-coding-core/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Core traits for the [`arithmetic-coding`](https://github.com/danieleades/arithmetic-coding) crate 2 | 3 | #![deny(missing_docs, missing_debug_implementations)] 4 | //#![feature(associated_type_defaults)] 5 | 6 | mod bitstore; 7 | pub use bitstore::BitStore; 8 | 9 | mod model; 10 | pub use model::{fixed_length, max_length, one_shot, Model}; 11 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | ## Run the tests 4 | 5 | cargo test 6 | 7 | ## Run the bench tests 8 | 9 | cargo bench 10 | 11 | ## Run the fuzz tests 12 | 13 | (requires [cargo-fuzz](https://github.com/rust-fuzz/cargo-fuzz)) 14 | 15 | cargo fuzz run fuzz_target_1 16 | 17 | ## Run the examples 18 | 19 | cargo run --example=${EXAMPLE} 20 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/python-bindings/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["maturin>=1.0,<2.0"] 3 | build-backend = "maturin" 4 | 5 | [project] 6 | name = "zipf_encoding" 7 | requires-python = ">=3.7" 8 | classifiers = [ 9 | "Programming Language :: Rust", 10 | "Programming Language :: Python :: Implementation :: CPython", 11 | "Programming Language :: Python :: Implementation :: PyPy", 12 | ] 13 | 14 | [tool.maturin] 15 | features = ["extension-module"] -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/python-bindings/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "zipf_encoding" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [lib] 7 | name = "zipf_encoding" 8 | crate-type = ["cdylib"] 9 | 10 | [dependencies] 11 | pyo3 = { version = "0.18.3", features = ["extension-module"] } 12 | arithmetic-coding = { path = ".." 
}
13 | bitstream-io = "2.0.0"
14 | thiserror = "1.0.30"
15 | 
16 | [features]
17 | extension-module = ["pyo3/extension-module"]
18 | default = ["extension-module"]
--------------------------------------------------------------------------------
/lib/diffc/rcc/arithmetic-coding/python-bindings/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from setuptools_rust import Binding, RustExtension
3 | 
4 | setup(
5 | name="zipf_encoding",
6 | version="0.1.0",
7 | rust_extensions=[
8 | RustExtension("zipf_encoding.zipf_encoding", binding=Binding.PyO3)
9 | ],
10 | packages=["zipf_encoding"],
11 | # zip_safe flag can be set to False if you want to load the extension from inside the ZIP file
12 | zip_safe=False,
13 | )
14 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: diffc
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | dependencies:
7 | - python=3.8
8 | - pytorch
9 | - torchvision
10 | - torchaudio
11 | - cupy
12 | - rust
13 | - pip
14 | - pip:
15 | - diffusers
16 | - transformers
17 | - easydict
18 | - numpy
19 | - pandas
20 | - Pillow
21 | - PyYAML
22 | - setuptools
23 | - setuptools_rust
24 | - scikit-image
25 | - tqdm
26 | - lpips
27 | - -e lib/diffc/rcc/arithmetic-coding/python-bindings
28 | 
--------------------------------------------------------------------------------
/lib/diffc/rcc/arithmetic-coding/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! Arithmetic coding library
2 | 
3 | #![deny(missing_docs, missing_debug_implementations)]
4 | 
5 | pub use arithmetic_coding_core::{fixed_length, max_length, one_shot, BitStore, Model};
6 | 
7 | pub mod decoder;
8 | pub mod encoder;
9 | 
10 | pub use decoder::Decoder;
11 | pub use encoder::Encoder;
12 | 
13 | /// Errors that can occur during encoding/decoding
14 | #[derive(Debug, thiserror::Error)]
15 | pub enum Error<E: std::error::Error> {
16 | /// Io error when reading/writing bits from a stream
17 | #[error("io error: {0}")]
18 | Io(#[from] std::io::Error),
19 | 
20 | /// Invalid symbol
21 | #[error("invalid symbol: {0}")]
22 | ValueError(E),
23 | }
24 | 
--------------------------------------------------------------------------------
/lib/diffc/rcc/arithmetic-coding/arithmetic-coding-core/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "arithmetic-coding-core"
3 | description = "core traits for the 'arithmetic-coding' crate"
4 | version = "0.3.0"
5 | edition = "2021"
6 | license = "MIT"
7 | keywords = ["compression", "encoding", "arithmetic-coding", "lossless"]
8 | categories = ["compression", "encoding", "parsing"]
9 | repository = "https://github.com/danieleades/arithmetic-coding"
10 | 
11 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
12 | 
13 | [dependencies]
14 | thiserror = { workspace = true }
15 | 
16 | [lints]
17 | workspace = true
18 | 
--------------------------------------------------------------------------------
/lib/models/SD15.py:
--------------------------------------------------------------------------------
1 | from lib.models.SD import SDModel
2 | import torch
3 | 
4 | 
5 | class SD15Model(SDModel):
6 | def __init__(self, device="cuda", dtype=torch.float16):
7 | super().__init__(
8 | model_id="runwayml/stable-diffusion-v1-5", device=device, dtype=dtype
9 | )
10 | 
11 | def _get_noise_pred(self, latent_model_input,
timestep, encoder_hidden_states): 12 | """ 13 | Get noise prediction from SD 1.5 UNet (which directly outputs epsilon/noise). 14 | """ 15 | return self.unet( 16 | latent_model_input, timestep, encoder_hidden_states=encoder_hidden_states 17 | ).sample 18 | -------------------------------------------------------------------------------- /lib/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def np_to_torch_img(img_np): 6 | img_pt = torch.tensor(img_np.astype("float") / 255) 7 | img_pt = img_pt.permute(2, 0, 1).unsqueeze(0).half().to("cuda") 8 | return img_pt 9 | 10 | 11 | def pil_to_torch_img(img_pil): 12 | return np_to_torch_img(np.array(img_pil)) 13 | 14 | 15 | def torch_to_np_img(img): 16 | return img[0].permute(1, 2, 0).clip(0, 1).detach().cpu().numpy() 17 | 18 | 19 | def np_to_pil_img(img): 20 | from PIL import Image 21 | 22 | return Image.fromarray((img * 255).astype("uint8")) 23 | 24 | 25 | def torch_to_pil_img(img): 26 | return np_to_pil_img(torch_to_np_img(img)) 27 | -------------------------------------------------------------------------------- /lib/diffc/utils/q.py: -------------------------------------------------------------------------------- 1 | from lib.diffc.utils.alpha_beta import get_alpha_prod_and_beta_prod 2 | 3 | 4 | def Q(noisy_latent, target_latent, current_snr, prev_snr): 5 | alpha_prod_t, beta_prod_t = get_alpha_prod_and_beta_prod(current_snr) 6 | alpha_prod_t_prev, beta_prod_t_prev = get_alpha_prod_and_beta_prod(prev_snr) 7 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev 8 | current_beta_t = 1 - current_alpha_t 9 | 10 | pred_original_sample_coeff = ( 11 | alpha_prod_t_prev ** (0.5) * current_beta_t 12 | ) / beta_prod_t 13 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t 14 | 15 | mu = ( 16 | pred_original_sample_coeff * target_latent + current_sample_coeff * noisy_latent 17 | ) 18 | 19 | return mu 20 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "arithmetic-coding" 3 | description = "fast and flexible arithmetic coding library" 4 | version = "0.3.1" 5 | edition = "2021" 6 | license = "MIT" 7 | keywords = ["compression", "encoding", "arithmetic-coding", "lossless"] 8 | categories = ["compression", "encoding", "parsing"] 9 | repository = "https://github.com/danieleades/arithmetic-coding" 10 | 11 | [workspace] 12 | members = [".", "arithmetic-coding-core", "python-bindings"] 13 | 14 | [workspace.dependencies] 15 | thiserror = "1.0.30" 16 | 17 | [workspace.lints.clippy] 18 | cargo = "deny" 19 | all = "deny" 20 | nursery = "warn" 21 | pedantic = "warn" 22 | 23 | [dependencies] 24 | arithmetic-coding-core = { path = "./arithmetic-coding-core", version = "0.3.0" } 25 | bitstream-io = "2.0.0" 26 | thiserror = { workspace = true } 27 | 28 | [dev-dependencies] 29 | criterion = "0.5.1" 30 | test-case = "3.0.0" 31 | 32 | [lints] 33 | workspace = true 34 | -------------------------------------------------------------------------------- /lib/models/latent_noise_prediction_model.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class LatentNoisePredictionModel(ABC): 5 | @abstractmethod 6 | def image_to_latent(self, img_pt): 7 | """Convert image to latent representation.""" 
8 | pass 9 | 10 | @abstractmethod 11 | def latent_to_image(self, latent): 12 | """Convert latent representation to image.""" 13 | pass 14 | 15 | @abstractmethod 16 | def configure(self, prompt, prompt_guidance, image_width, image_height): 17 | """Configure the model with given parameters.""" 18 | pass 19 | 20 | @abstractmethod 21 | def get_timestep_snr(self, timestep): 22 | """Return the signal to noise ratio (snr) that the model expects at this timestep.""" 23 | pass 24 | 25 | @abstractmethod 26 | def predict_noise(self, noisy_latent, timestep): 27 | """Predict noise for given latent at timestep.""" 28 | pass 29 | -------------------------------------------------------------------------------- /lib/models/SD21.py: -------------------------------------------------------------------------------- 1 | from lib.models.SD import SDModel 2 | import torch 3 | 4 | 5 | class SD21Model(SDModel): 6 | def __init__(self, device="cuda", dtype=torch.float16): 7 | super().__init__( 8 | model_id="stabilityai/stable-diffusion-2-1", device=device, dtype=dtype 9 | ) 10 | 11 | def _get_noise_pred(self, latent_model_input, timestep, encoder_hidden_states): 12 | """ 13 | Get noise prediction from SD 2.1 UNet (converting v-prediction to epsilon). 14 | """ 15 | # Get v-prediction from model 16 | v_prediction = self.unet( 17 | latent_model_input, timestep, encoder_hidden_states=encoder_hidden_states 18 | ).sample 19 | 20 | # Get alpha and beta values for current timestep 21 | alpha_prod_t = self.reference_scheduler.alphas_cumprod[timestep - 1] 22 | beta_prod_t = 1 - alpha_prod_t 23 | 24 | # Convert v-prediction to epsilon (noise) prediction 25 | noise_pred = (alpha_prod_t ** 0.5) * v_prediction + ( 26 | beta_prod_t ** 0.5 27 | ) * latent_model_input 28 | 29 | return noise_pred 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Jeremy Vonderfecht and Feng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/configs/SD-1.5-prompt.yaml:
--------------------------------------------------------------------------------
1 | max_chunk_size: 16
2 | chunk_padding: 2
3 | model: 'SD1.5'
4 | encoding_guidance_scale: 1
5 | denoising_guidance_scale: 5
6 | encoding_timesteps: [972, 949, 929, 897, 869, 834, 805, 780, 751, 726, 704, 688, 670, 648, 627, 608, 591, 578, 561, 546, 530, 520, 510, 498, 491, 480, 465, 455, 447, 438, 429, 419, 410, 402, 390, 380, 371, 361, 353, 345, 336, 326, 319, 313, 305, 296, 289, 282, 276, 269, 261, 254, 247, 242, 237, 231, 224, 219, 213, 209, 204, 200, 194, 189, 185, 181, 175, 170, 167, 163, 160, 156, 153, 149, 146, 143, 139, 135, 132, 129, 125, 121, 118, 116, 113, 110, 107, 104, 101, 99, 96, 94, 92, 90, 87, 85, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
7 | manual_dkl_per_step: null
8 | recon_timesteps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10]
9 | denoising_timesteps: [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161, 141, 121, 101, 81, 61, 41, 21, 10, 5, 0]
10 | 
--------------------------------------------------------------------------------
/lib/diffc/rcc/arithmetic-coding/arithmetic-coding-core/src/bitstore.rs:
--------------------------------------------------------------------------------
1 | use std::ops::{Add, AddAssign, Div, Mul, Shl, ShlAssign, Sub};
2 | 
3 | /// A trait for a type that can be used for the internal integer representation
4 | /// of an encoder or decoder
5 | pub trait BitStore:
6 | Shl<u32, Output = Self>
7 | + ShlAssign<u32>
8 | + Sub<Output = Self>
9 | + Add<Output = Self>
10 | + Mul<Output = Self>
11 | + Div<Output = Self>
12 | + AddAssign
13 | + PartialOrd
14 | + Copy
15 | + std::fmt::Debug
16 | {
17 | /// the number of bits needed to represent this type
18 | const BITS: u32;
19 | 
20 | /// the additive identity
21 | const ZERO: Self;
22 | 
23 | /// the multiplicative identity
24 | const ONE: Self;
25 | 
26 | /// integer base-2 logarithm, rounded down
27 | fn log2(self) -> u32;
28 | }
29 | 
30 | macro_rules! impl_bitstore {
31 | ($t:ty) => {
32 | impl BitStore for $t {
33 | const BITS: u32 = Self::BITS;
34 | const ONE: Self = 1;
35 | const ZERO: Self = 0;
36 | 
37 | fn log2(self) -> u32 {
38 | Self::ilog2(self)
39 | }
40 | }
41 | };
42 | }
43 | 
44 | impl_bitstore! {u32}
45 | impl_bitstore! {u64}
46 | impl_bitstore! {u128}
47 | impl_bitstore!
{usize}
48 | 
--------------------------------------------------------------------------------
/lib/diffc/utils/p.py:
--------------------------------------------------------------------------------
1 | from lib.diffc.utils.alpha_beta import get_alpha_prod_and_beta_prod
2 | import torch
3 | 
4 | 
5 | def P(noisy_latent, noise_prediction, current_snr, prev_snr):
6 | alpha_prod_t, beta_prod_t = get_alpha_prod_and_beta_prod(current_snr)
7 | alpha_prod_t_prev, beta_prod_t_prev = get_alpha_prod_and_beta_prod(prev_snr)
8 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev
9 | current_beta_t = 1 - current_alpha_t
10 | 
11 | pred_original_sample = (
12 | noisy_latent - beta_prod_t ** (0.5) * noise_prediction
13 | ) / alpha_prod_t ** (0.5)
14 | 
15 | pred_original_sample_coeff = (
16 | alpha_prod_t_prev ** (0.5) * current_beta_t
17 | ) / beta_prod_t
18 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
19 | 
20 | # 5. Compute predicted previous sample µ_t
21 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
22 | pred_prev_sample = (
23 | pred_original_sample_coeff * pred_original_sample
24 | + current_sample_coeff * noisy_latent
25 | )
26 | 
27 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
28 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
29 | std = variance ** (0.5)
30 | return pred_prev_sample, std
31 | 
--------------------------------------------------------------------------------
/lib/diffc/rcc/arithmetic-coding/README.md:
--------------------------------------------------------------------------------
1 | # Arithmetic Coding
2 | 
3 | [![Latest Docs](https://docs.rs/arithmetic-coding/badge.svg)](https://docs.rs/arithmetic-coding/)
4 | ![Continuous integration](https://github.com/danieleades/arithmetic-coding/workflows/Continuous%20integration/badge.svg)
5 | [![codecov](https://codecov.io/gh/danieleades/arithmetic-coding/branch/main/graph/badge.svg?token=1qITX2tR0J)](https://codecov.io/gh/danieleades/arithmetic-coding)
6 | 
7 | 
8 | A symbolic [arithmetic coding](https://en.wikipedia.org/wiki/Arithmetic_coding) library.
9 | 
10 | Extending this library is as simple as implementing the `Model` trait for your own type, and then plugging it into the provided `Encoder`/`Decoder`. Supports both fixed-length and variable-length encoding, as well as both adaptive and non-adaptive models.
11 | 
12 | Take a look at the [API docs](https://docs.rs/arithmetic-coding/) or the [examples](https://github.com/danieleades/arithmetic-coding/tree/main/examples).
13 | 
14 | This crate is heavily inspired by
15 | 
16 | - [arcode-rs](https://github.com/cgburgess/arcode-rs)
17 | - [Data Compression With Arithmetic Coding - *Mark Nelson*, 2014](https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html)
18 | 
19 | *Was this useful?
[Buy me a coffee](https://github.com/sponsors/danieleades/sponsorships?sponsor=danieleades&preview=true&frequency=recurring&amount=5)*
--------------------------------------------------------------------------------
/lib/diffc/rcc/cuda_kernels.cu:
--------------------------------------------------------------------------------
1 | #include <curand_kernel.h>
2 | 
3 | extern "C" {
4 | 
5 | __global__ void generate_sample_kernel(
6 | int dim,
7 | unsigned long long shared_seed,
8 | unsigned long long idx,
9 | float* sample_out) {
10 | 
11 | if (threadIdx.x == 0 && blockIdx.x == 0) {
12 | curandState state;
13 | curand_init(shared_seed, 0, idx * dim, &state);
14 | //curand_init(shared_seed + idx, 0, 0, &state);
15 | 
16 | for (int i = 0; i < dim; i++) {
17 | sample_out[i] = curand_normal(&state);
18 | }
19 | }
20 | }
21 | 
22 | __global__ void reverse_channel_encode_kernel(
23 | const float* mu_q,
24 | int dim,
25 | unsigned long long K,
26 | unsigned long long shared_seed,
27 | float* log_w
28 | ) {
29 | int idx = blockIdx.x * blockDim.x + threadIdx.x;
30 | if (idx >= K) return;
31 | 
32 | curandState state;
33 | curand_init(shared_seed, 0, idx * dim, &state);
34 | 
35 | //curand_init(shared_seed + idx, 0, 0, &state);
36 | 
37 | float log_w_value = 0.0f;
38 | for (int i = 0; i < dim; i++) {
39 | float sample_value = curand_normal(&state);
40 | //log_w_value += 0.5 * (sample_value * sample_value - (sample_value - mu_q[i]) * (sample_value - mu_q[i]));
41 | log_w_value += sample_value*mu_q[i];
42 | }
43 | 
44 | log_w[idx] = log_w_value;
45 | }
46 | 
47 | } // extern "C"
--------------------------------------------------------------------------------
/lib/diffc/rcc/arithmetic-coding/python-bindings/src/lib.rs:
--------------------------------------------------------------------------------
1 | use pyo3::prelude::*;
2 | use arithmetic_coding::{Encoder, Decoder};
3 | use bitstream_io::{BigEndian, BitReader, BitWrite, BitWriter};
4 | 
5 | mod zipf;
6 | use zipf::ZipfModel;
7 | 
8 | #[pyfunction]
9 | fn encode_zipf(s_values: Vec<f64>, n_values: Vec<u64>, numbers: Vec<u64>) -> PyResult<Vec<u8>> {
10 | let model = ZipfModel::new(s_values, n_values);
11 | let mut bitwriter = BitWriter::endian(Vec::new(), BigEndian);
12 | let mut encoder = Encoder::new(model, &mut bitwriter);
13 | 
14 | encoder.encode_all(numbers).map_err(|_e| PyErr::new::<pyo3::exceptions::PyValueError, _>("Encoding error"))?;
15 | bitwriter.byte_align().map_err(|_e| PyErr::new::<pyo3::exceptions::PyValueError, _>("Byte alignment error"))?;
16 | 
17 | Ok(bitwriter.into_writer())
18 | }
19 | 
20 | #[pyfunction]
21 | fn decode_zipf(s_values: Vec<f64>, n_values: Vec<u64>, encoded: Vec<u8>) -> PyResult<Vec<u64>> {
22 | let model = ZipfModel::new(s_values, n_values);
23 | let bitreader = BitReader::endian(encoded.as_slice(), BigEndian);
24 | let mut decoder = Decoder::new(model, bitreader);
25 | 
26 | decoder.decode_all()
27 | .collect::<Result<Vec<_>, _>>()
28 | .map_err(|_e| PyErr::new::<pyo3::exceptions::PyValueError, _>("Decoding error"))
29 | }
30 | 
31 | #[pymodule]
32 | fn zipf_encoding(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
33 | m.add_function(wrap_pyfunction!(encode_zipf, m)?)?;
34 | m.add_function(wrap_pyfunction!(decode_zipf, m)?)?;
35 | Ok(())
36 | }
--------------------------------------------------------------------------------
/configs/SD-1.5-base.yaml:
--------------------------------------------------------------------------------
1 | max_chunk_size: 16
2 | chunk_padding: 2
3 | model: 'SD1.5'
4 | encoding_guidance_scale: 0
5 | denoising_guidance_scale: 0
6 | encoding_timesteps: [972, 949, 929, 897, 869, 834, 805, 780, 751, 726, 704, 688, 670, 648, 627, 608, 591, 578, 561, 546, 530, 520, 510, 498, 491, 480, 465,
455, 447, 438, 429, 419, 410, 402, 390, 380, 371, 361, 353, 345, 336, 326, 319, 313, 305, 296, 289, 282, 276, 269, 261, 254, 247, 242, 237, 231, 224, 219, 213, 209, 204, 200, 194, 189, 185, 181, 175, 170, 167, 163, 160, 156, 153, 149, 146, 143, 139, 135, 132, 129, 125, 121, 118, 116, 113, 110, 107, 104, 101, 99, 96, 94, 92, 90, 87, 85, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] 7 | manual_dkl_per_step: null 8 | # manual_dkl_per_step: [20, 20, 20, 39, 42, 67, 67, 68, 92, 92, 92, 73, 90, 122, 130, 129, 125, 102, 143, 135, 154, 101, 106, 133, 81, 132, 191, 134, 111, 129, 134, 155, 145, 133, 210, 183, 171, 199, 165, 171, 200, 233, 169, 149, 206, 243, 196, 203, 179, 216, 257, 234, 243, 179, 184, 228, 277, 204, 254, 174, 224, 184, 286, 247, 203, 209, 326, 282, 173, 238, 183, 251, 193, 265, 204, 210, 290, 300, 231, 238, 329, 342, 265, 181, 279, 288, 298, 308, 319, 219, 340, 233, 239, 245, 382, 263, 409, 282, 291, 300, 309, 320, 330, 342, 354, 368, 382, 397, 414, 431, 450, 471, 493, 517, 267, 274, 281, 289, 296, 304, 313, 322, 332, 343, 354, 365, 378, 391, 405, 420, 436, 453, 472, 493, 514, 538, 563, 592, 622, 655, 692, 734, 779, 831, 888, 955, 1030, 1117, 1218, 1336, 1478, 1650, 1864, 2133, 2488, 2971, 3646, 4681, 6451, 10077, 21297] 9 | recon_timesteps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10] 10 | denoising_timesteps: [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161, 141, 121, 101, 81, 61, 41, 21, 10, 5, 0] 11 | -------------------------------------------------------------------------------- /configs/SD-2.1-base.yaml: -------------------------------------------------------------------------------- 1 | max_chunk_size: 16 2 | chunk_padding: 2 3 | model: 'SD2.1' 4 | encoding_guidance_scale: 0 5 | denoising_guidance_scale: 0 6 | encoding_timesteps: [972, 949, 929, 897, 869, 834, 805, 780, 751, 726, 704, 688, 670, 648, 627, 608, 591, 578, 561, 546, 530, 520, 510, 498, 491, 480, 465, 455, 447, 438, 429, 419, 410, 402, 390, 380, 371, 361, 353, 345, 336, 326, 319, 313, 305, 296, 289, 282, 276, 269, 261, 254, 247, 242, 237, 231, 224, 219, 213, 209, 204, 200, 194, 189, 185, 181, 175, 170, 167, 163, 160, 156, 153, 149, 146, 143, 139, 135, 132, 129, 125, 121, 118, 116, 113, 110, 107, 104, 101, 99, 96, 94, 92, 90, 87, 85, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] 7 | manual_dkl_per_step: null 8 | # manual_dkl_per_step: [21, 20, 20, 40, 43, 66, 67, 67, 92, 92, 92, 74, 91, 123, 130, 128, 125, 102, 143, 135, 155, 102, 107, 134, 81, 133, 193, 135, 112, 130, 136, 157, 147, 135, 213, 186, 174, 202, 168, 174, 205, 237, 172, 152, 210, 247, 198, 206, 181, 219, 261, 237, 246, 181, 186, 231, 281, 207, 256, 175, 225, 186, 289, 250, 205, 211, 329, 285, 175, 241, 185, 254, 195, 268, 206, 212, 292, 302, 233, 239, 332, 345, 267, 182, 281, 290, 300, 311, 322, 220, 342, 234, 240, 246, 384, 263, 411, 284, 292, 301, 311, 321, 332, 344, 356, 370, 384, 399, 416, 433, 452, 473, 495, 520, 268, 275, 282, 290, 298, 306, 315, 
325, 334, 345, 356, 368, 381, 394, 408, 424, 440, 457, 476, 496, 518, 542, 567, 596, 626, 660, 698, 740, 786, 838, 895, 962, 1039, 1125, 1228, 1345, 1488, 1663, 1877, 2148, 2506, 2992, 3674, 4715, 6488, 10135, 21339] 9 | recon_timesteps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10] 10 | denoising_timesteps: [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161, 141, 121, 101, 81, 61, 41, 21, 10, 5, 0] 11 | -------------------------------------------------------------------------------- /configs/SDXL-base.yaml: -------------------------------------------------------------------------------- 1 | max_chunk_size: 16 2 | chunk_padding: 2 3 | model: 'SDXL' 4 | use_refiner: true 5 | encoding_guidance_scale: 0 6 | denoising_guidance_scale: 0 7 | encoding_timesteps: [972, 949, 929, 897, 869, 834, 805, 780, 751, 726, 704, 688, 670, 648, 627, 608, 591, 578, 561, 546, 530, 520, 510, 498, 491, 480, 465, 455, 447, 438, 429, 419, 410, 402, 390, 380, 371, 361, 353, 345, 336, 326, 319, 313, 305, 296, 289, 282, 276, 269, 261, 254, 247, 242, 237, 231, 224, 219, 213, 209, 204, 200, 194, 189, 185, 181, 175, 170, 167, 163, 160, 156, 153, 149, 146, 143, 139, 135, 132, 129, 125, 121, 118, 116, 113, 110, 107, 104, 101, 99, 96, 94, 92, 90, 87, 85, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1] 8 | manual_dkl_per_step: null 9 | #manual_dkl_per_step: [64, 57, 56, 109, 112, 169, 167, 163, 221, 224, 226, 182, 226, 310, 333, 335, 328, 272, 383, 367, 424, 283, 297, 376, 230, 379, 553, 392, 327, 384, 402, 469, 443, 411, 653, 573, 541, 631, 526, 549, 645, 756, 550, 487, 677, 801, 648, 673, 597, 724, 863, 787, 820, 604, 622, 770, 936, 688, 854, 586, 754, 619, 975, 840, 690, 707, 1100, 948, 581, 795, 608, 834, 639, 876, 671, 688, 946, 975, 747, 765, 1052, 1086, 832, 564, 866, 887, 910, 936, 961, 652, 1007, 682, 692, 707, 1092, 741, 1145, 775, 792, 809, 827, 847, 867, 888, 906, 928, 953, 980, 1004, 1030, 1061, 1093, 1130, 1165, 591, 601, 611, 622, 632, 642, 654, 666, 680, 692, 709, 722, 741, 757, 775, 792, 815, 837, 861, 887, 917, 955, 986, 1024, 1064, 1106, 1151, 1210, 1270, 1342, 1426, 1522, 1647, 1771, 1929, 2116, 2338, 2619, 2978, 3438, 4045, 4895, 6112, 8010, 11330, 18360, 40906] 10 | recon_timesteps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10] 11 | denoising_timesteps: [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161, 141, 121, 101, 81, 61, 41, 21, 10, 5, 0] 12 | -------------------------------------------------------------------------------- /configs/Flux-base.yaml: -------------------------------------------------------------------------------- 1 | max_chunk_size: 16 2 | chunk_padding: 2 3 | model: 'Flux' 4 | use_refiner: true 5 | encoding_guidance_scale: 0 6 | denoising_guidance_scale: 0 7 | encoding_timesteps: [972, 949, 929, 897, 869, 834, 805, 780, 751, 726, 704, 688, 670, 648, 627, 608, 591, 578, 561, 546, 530, 520, 510, 498, 491, 480, 465, 455, 447, 438, 429, 419, 410, 402, 390, 380, 371, 361, 353, 345, 336, 326, 319, 313, 305, 
296, 289, 282, 276, 269, 261, 254, 247, 242, 237, 231, 224, 219, 213, 209, 204, 200, 194, 189, 185, 181, 175, 170, 167, 163, 160, 156, 153, 149, 146, 143, 139, 135, 132, 129, 125, 121, 118, 116, 113, 110, 107, 104, 101, 99, 96, 94, 92, 90, 87, 85, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10] 8 | manual_dkl_per_step: null 9 | #manual_dkl_per_step: [73, 121, 162, 355, 418, 736, 688, 727, 1012, 1042, 1061, 877, 1103, 1538, 1681, 1720, 1727, 1449, 2080, 2029, 2383, 1599, 1701, 2192, 1344, 2246, 3335, 2393, 2021, 2397, 2533, 2994, 2853, 2668, 4279, 3797, 3626, 4292, 3619, 3811, 4538, 5376, 3949, 3523, 4944, 5896, 4798, 5030, 4496, 5487, 6624, 6080, 6391, 4724, 4893, 6135, 7518, 5552, 6967, 4773, 6204, 5098, 8004, 6915, 5688, 5850, 9208, 7978, 4880, 6719, 5152, 7110, 5452, 7527, 5767, 5894, 8163, 8451, 6487, 6683, 9270, 9621, 7416, 5034, 7806, 8063, 8365, 8648, 8971, 6121, 9551, 6535, 6706, 6907, 10831, 7459, 11734, 8056, 8343, 8624, 8925, 9256, 9599, 9997, 10435, 10886, 11395, 11903, 12463, 13094, 13760, 14476, 15308, 16185, 8363, 8617, 8882, 9144, 9440, 9737, 10084, 10433, 10788, 11178, 11576, 11981, 12473, 12970, 13488, 14042, 14665, 15304, 16009, 16769, 17574, 18504, 19467, 20484, 21667, 22954, 24359, 25964, 27756, 29768, 32126, 34780, 37722, 41421, 45459, 50706, 56769, 64347] 10 | recon_timesteps: [900, 800, 700, 600, 500, 400, 300, 200, 100, 90, 80, 70, 60, 50, 40, 30, 20, 10] 11 | denoising_timesteps: [981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161, 141, 121, 101, 81, 61, 41, 21, 10, 5, 0] 12 | -------------------------------------------------------------------------------- /lib/blip.py: -------------------------------------------------------------------------------- 1 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 2 | from PIL import Image 3 | import torch 4 | from pathlib import Path 5 | import pandas as pd 6 | 7 | 8 | class BlipCaptioner: 9 | def __init__( 10 | self, model_name: str = "Salesforce/blip2-opt-2.7b-coco", max_length: int = 75 11 | ): 12 | """Initialize BLIP-2 model for batch captioning.""" 13 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 14 | self.processor = Blip2Processor.from_pretrained(model_name) 15 | self.model = Blip2ForConditionalGeneration.from_pretrained(model_name).to( 16 | self.device 17 | ) 18 | self.max_length = max_length 19 | 20 | def generate_caption(self, image: Image.Image) -> str: 21 | """Generate caption for a single image.""" 22 | inputs = self.processor(images=image, return_tensors="pt").to(self.device) 23 | generated_ids = self.model.generate(**inputs, max_length=self.max_length) 24 | caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[ 25 | 0 26 | ].strip() 27 | return caption 28 | 29 | def process_images(self, image_paths) -> dict: 30 | """Process multiple images and return a dictionary of their captions.""" 31 | captions = {} 32 | for path in image_paths: 33 | img = Image.open(path).convert("RGB") 34 | caption = self.generate_caption(img) 35 | captions[str(path)] = caption 36 | return captions 37 | 38 | def process_and_save(self, image_paths, output_dir: Path) -> dict: 39 | """Process images, save captions to CSV, and return caption 
dictionary."""
40 | captions = self.process_images(image_paths)
41 | 
42 | # Create DataFrame and save to CSV
43 | df = pd.DataFrame(
44 | [{"image_path": k, "caption": v} for k, v in captions.items()]
45 | )
46 | df.to_csv(output_dir / "blip_captions.csv", index=False)
47 | 
48 | return captions
49 | 
50 | def __del__(self):
51 | """Cleanup any GPU memory."""
52 | if hasattr(self, "model"):
53 | del self.model
54 | if hasattr(self, "processor"):
55 | del self.processor
56 | if torch.cuda.is_available():
57 | torch.cuda.empty_cache()
58 | 
--------------------------------------------------------------------------------
/lib/diffc/rcc/pfr.py:
--------------------------------------------------------------------------------
1 | import cupy as cp
2 | import numpy as np
3 | 
4 | # Load the CUDA module
5 | cuda_code = open("lib/diffc/rcc/cuda_kernels.cu", "r").read()
6 | cuda_module = cp.RawModule(code=cuda_code)
7 | 
8 | # Get the kernel functions
9 | reverse_channel_encode_kernel = cuda_module.get_function(
10 | "reverse_channel_encode_kernel"
11 | )
12 | generate_sample_kernel = cuda_module.get_function("generate_sample_kernel")
13 | 
14 | 
15 | def generate_sample(dim, shared_seed, sample_seed):
16 | sample_out = cp.empty(dim, dtype=cp.float32)
17 | 
18 | generate_sample_kernel(
19 | (1, 1, 1),
20 | (1, 1, 1),
21 | (cp.int32(dim), cp.uint64(shared_seed), cp.uint64(sample_seed), sample_out),
22 | )
23 | 
24 | return sample_out.get()
25 | 
26 | 
27 | def _reverse_channel_encode(mu_q_in, K, shared_seed=0):
28 | mu_q = cp.asarray(mu_q_in, dtype=cp.float32)
29 | dim = mu_q.shape[0]
30 | 
31 | # Allocate memory on GPU
32 | log_w = cp.empty(K, dtype=cp.float32)
33 | 
34 | 
35 | # Set up grid and block dimensions
36 | block_size = 256
37 | grid_size = (K + block_size - 1) // block_size
38 | 
39 | # Generate vector of random exponentials
40 | t = cp.random.exponential(scale=1.0, size=K)
41 | # take the log of the cumsum of those: these are the arrival times T_i of a Poisson process
42 | log_cumsum_t = cp.log(cp.cumsum(t))
43 | 
44 | # Launch main kernel to compute the importance weights log w_i
45 | reverse_channel_encode_kernel(
46 | (grid_size, 1, 1),
47 | (block_size, 1, 1),
48 | (mu_q, cp.int32(dim), cp.uint64(K), cp.uint64(shared_seed), log_w),
49 | )
50 | cp.cuda.stream.get_current_stream().synchronize()
51 | 
52 | s = log_cumsum_t - log_w  # PFR selection: argmin_i (log T_i - log w_i)
53 | 
54 | winning_seed = cp.argmin(s).item()
55 | sample = generate_sample(dim, shared_seed, winning_seed)
56 | 
57 | return winning_seed, sample.astype(np.float16)
58 | 
59 | 
60 | def _reverse_channel_decode(dim, shared_seed, winning_seed):
61 | sample = generate_sample(dim, shared_seed, winning_seed)
62 | return sample.astype(np.float16)
63 | 
64 | 
65 | def reverse_channel_encode(mu_q, K=None, shared_seed=0):
66 | diff = (mu_q).astype(np.float32) # Convert to float32
67 | seed, sample = _reverse_channel_encode(diff, K, shared_seed)
68 | return seed, sample
69 | 
70 | 
71 | def reverse_channel_decode(dim, seed, shared_seed=0):
72 | """
73 | Given an isotropic gaussian with unit variance centered at mu_q,
74 | and a random seed, generate a sample from the distribution q.
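Decoding replays the encoder's shared randomness: the winning seed re-derives, via the same curand stream seeded by (shared_seed, seed), exactly the candidate sample the encoder selected, so mu_q itself never needs to be transmitted.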
75 | """ 76 | sample = _reverse_channel_decode(dim, shared_seed, seed) 77 | return sample.astype(np.float16) 78 | -------------------------------------------------------------------------------- /lib/diffc/denoise.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from lib.diffc.utils.alpha_beta import get_alpha_prod_and_beta_prod 3 | import torch 4 | 5 | @torch.no_grad() 6 | def denoise(noisy_latent, latent_timestep, timestep_schedule, noise_prediction_model): 7 | """ 8 | Perform probability-flow-based denoising upon the noisy latent. 9 | 10 | Args: 11 | noisy_latent: latent to be denoised. 12 | latent_SNR: signal to noise ratio of the latent to be denoised. 13 | SNR_schedule (List[float]): List of signal-to-noise ratios in decreasing order, 14 | matching the schedule used during encoding. Last element should be 0 for fully denoised image. 15 | predict_noise (callable): Function that predicts the noise component given a noisy 16 | latent and its SNR. 17 | 18 | """ 19 | latent = noisy_latent 20 | current_timestep = latent_timestep 21 | current_snr = noise_prediction_model.get_timestep_snr(current_timestep) 22 | 23 | timestep_schedule = [t for t in timestep_schedule if t < latent_timestep] 24 | 25 | for prev_timestep in tqdm( 26 | timestep_schedule 27 | ): # "previous" as in higher than the current snr 28 | noise_prediction = noise_prediction_model.predict_noise( 29 | latent.to(noise_prediction_model.dtype), current_timestep 30 | ).to(torch.float32) 31 | prev_snr = noise_prediction_model.get_timestep_snr(prev_timestep) 32 | 33 | alpha_prod_t, beta_prod_t = get_alpha_prod_and_beta_prod(current_snr) 34 | alpha_prod_t_prev, beta_prod_t_prev = get_alpha_prod_and_beta_prod(prev_snr) 35 | 36 | # if int(prev_timestep) == 0: 37 | # from IPython.core.debugger import set_trace 38 | # set_trace() 39 | 40 | beta_prod_t = 1 - alpha_prod_t 41 | 42 | # 3. compute predicted original sample from predicted noise also called 43 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 44 | sample = latent 45 | model_output = noise_prediction 46 | pred_original_sample = ( 47 | sample - beta_prod_t ** (0.5) * model_output 48 | ) / alpha_prod_t ** (0.5) 49 | pred_epsilon = model_output 50 | 51 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 52 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon 53 | 54 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf 55 | latent = ( 56 | alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction 57 | ) 58 | 59 | current_timestep = prev_timestep 60 | current_snr = prev_snr 61 | 62 | return latent.to(noisy_latent.dtype) 63 | -------------------------------------------------------------------------------- /lib/diffc/decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.diffc.utils.p import P 3 | from tqdm import tqdm 4 | 5 | @torch.no_grad() 6 | def decode( 7 | image_width, 8 | image_height, 9 | timestep_schedule, 10 | noise_prediction_model, 11 | gaussian_channel_simulator, 12 | chunk_seeds_per_step, 13 | Dkl_per_step, 14 | seed, 15 | ): 16 | """Decodes a compressed image representation back into its latent space form. 17 | 18 | Args: 19 | latent_shape (torch.Size): Shape of the target latent tensor to be reconstructed. 
20 | timestep_schedule (List[float]): List of timesteps in decreasing order. 21 | predict_noise (callable): Function that predicts the noise component given a noisy 22 | latent and its SNR. 23 | gaussian_channel_simulator: Simulator used for gaussian channel reconstruction. 24 | chunk_seeds_per_step (List[List[int]]): Compressed representation of the image, 25 | consisting of lists of integer seeds for each denoising step. 26 | Dkl_per_step (List[float]): List of Kullback-Leibler divergence values per step, 27 | used to reconstruct the denoising process. 28 | seed (int): Random seed for reproducibility of the denoising process. 29 | 30 | Returns: 31 | torch.Tensor: The reconstructed latent representation of the image, obtained 32 | through progressive denoising steps guided by the compressed representation. 33 | """ 34 | 35 | 36 | device = noise_prediction_model.device 37 | dtype = noise_prediction_model.dtype 38 | 39 | dummy_image = torch.zeros((1, 3, image_height, image_width)).to(device).to(dtype) 40 | dummy_latent = noise_prediction_model.image_to_latent(dummy_image) 41 | 42 | torch.manual_seed(seed) 43 | noisy_latent = torch.randn(dummy_latent.shape, device=device, dtype=dtype) 44 | 45 | current_timestep = 1000 46 | current_snr = noise_prediction_model.get_timestep_snr(current_timestep) 47 | for step_index, (prev_timestep, chunk_seeds, Dkl) in tqdm(enumerate( 48 | zip(timestep_schedule, chunk_seeds_per_step, Dkl_per_step) 49 | ), total=len(chunk_seeds_per_step)): 50 | noise_prediction = noise_prediction_model.predict_noise( 51 | noisy_latent, current_timestep 52 | ) 53 | prev_snr = noise_prediction_model.get_timestep_snr(prev_timestep) 54 | p_mu, std = P(noisy_latent, noise_prediction, current_snr, prev_snr) 55 | sample = gaussian_channel_simulator.decode( 56 | chunk_seeds, noisy_latent.numel(), Dkl, seed=step_index 57 | ) 58 | reshaped_sample = ( 59 | torch.tensor(sample).reshape(noisy_latent.shape).to(device).to(dtype) 60 | ) 61 | noisy_latent = reshaped_sample * std + p_mu 62 | current_timestep = prev_timestep 63 | current_snr = prev_snr 64 | 65 | return noisy_latent 66 | -------------------------------------------------------------------------------- /lib/diffc/rcc/chunk_coding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from lib.diffc.rcc.pfr import reverse_channel_encode, reverse_channel_decode 3 | 4 | 5 | def partition_mu(dim, chunk_sizes, shared_seed=0): 6 | """ 7 | return an array of shape (dim,) which determines which chunk each dimension belongs to. 8 | the values in the array correspond to the indices of the chunk. 
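Dimensions are allotted to chunks in proportion to each chunk's size in bits, then shuffled with a generator seeded by shared_seed, so the encoder and decoder derive the identical partition.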
9 | """ 10 | total_bits = sum(chunk_sizes) 11 | chunk_ndims = [] 12 | for chunk_size in chunk_sizes[:-1]: 13 | chunk_ndims.append(int(dim * chunk_size / total_bits)) 14 | chunk_ndims.append(dim - sum(chunk_ndims)) 15 | 16 | partition_indices = np.concatenate( 17 | [np.full(ndims, i) for i, ndims in enumerate(chunk_ndims)] 18 | ) 19 | rng = np.random.default_rng(shared_seed) 20 | rng.shuffle(partition_indices) 21 | 22 | return partition_indices 23 | 24 | 25 | def combine_partitions(partition_indices, partitions): 26 | combined = np.zeros_like(partition_indices, dtype=partitions[0].dtype) 27 | for i, partition in enumerate(partitions): 28 | combined[partition_indices == i] = partition 29 | return combined 30 | 31 | 32 | def chunk_and_encode(mu, chunk_sizes, shared_seed=0): 33 | partition_indices = partition_mu(len(mu), chunk_sizes, shared_seed) 34 | 35 | partitions = [] 36 | seeds = [] 37 | for i, chunk_size in enumerate(chunk_sizes): 38 | chunk_mask = partition_indices == i 39 | mu_chunk = mu[chunk_mask] 40 | chunk_shared_seed = hash((shared_seed, i)) % (2 ** 32) 41 | seed, partition = reverse_channel_encode( 42 | mu_chunk, K=int(2 ** chunk_size), shared_seed=chunk_shared_seed 43 | ) 44 | seeds.append(seed) 45 | partitions.append(partition) 46 | 47 | return tuple(seeds), combine_partitions(partition_indices, partitions) 48 | 49 | 50 | def decode_from_chunks(dim, seeds, chunk_sizes, shared_seed=0): 51 | partition_indices = partition_mu(dim, chunk_sizes, shared_seed) 52 | 53 | partitions = [] 54 | for i, seed in enumerate(seeds): 55 | chunk_shared_seed = hash((shared_seed, i)) % (2 ** 32) 56 | chunk_dim = (partition_indices == i).sum() 57 | partition = reverse_channel_decode( 58 | chunk_dim, seed, shared_seed=chunk_shared_seed 59 | ) 60 | partitions.append(partition) 61 | return combine_partitions(partition_indices, partitions) 62 | 63 | 64 | def distribute_apples(m, n): 65 | """ 66 | Given m apples and n buckets, return how many apples to put in each bucket, to distribute as evenly as possible. 
67 | """ 68 | if n == 0: 69 | return [] 70 | 71 | base_apples = m // n 72 | extra_apples = m % n 73 | 74 | distribution = [base_apples] * n 75 | 76 | for i in range(extra_apples): 77 | distribution[i] += 1 78 | 79 | return tuple(distribution) 80 | 81 | 82 | def get_chunk_sizes(Dkl, max_size=8, chunk_padding_bits=2): 83 | n = int(np.ceil(Dkl)) 84 | num_chunks = int(np.ceil(n / (max_size - chunk_padding_bits))) 85 | return distribute_apples(n + chunk_padding_bits * num_chunks, num_chunks) 86 | -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | from lib import image_utils 2 | import torch 3 | 4 | 5 | def get_bpp(seed_tuples, zipf_s_vals, zipf_n_vals, recon_step_idx, num_pixels): 6 | from zipf_encoding import encode_zipf 7 | 8 | # get seeds up to and including the recon_step_idx 9 | recon_seeds = sum(map(list, seed_tuples[: recon_step_idx + 1]), []) 10 | 11 | encoding = encode_zipf( 12 | zipf_s_vals[: len(recon_seeds)], zipf_n_vals[: len(recon_seeds)], recon_seeds 13 | ) 14 | 15 | num_bits = len(encoding) * 8 16 | 17 | bpp = num_bits / num_pixels 18 | return bpp 19 | 20 | 21 | def get_psnr(recon, gt_pt): 22 | from skimage.metrics import peak_signal_noise_ratio 23 | 24 | return peak_signal_noise_ratio( 25 | image_utils.torch_to_np_img(recon), image_utils.torch_to_np_img(gt_pt) 26 | ) 27 | 28 | 29 | _get_lpips = None 30 | 31 | 32 | def get_lpips(recon, gt): 33 | global _get_lpips 34 | if _get_lpips is None: 35 | from lpips import LPIPS 36 | 37 | _get_lpips = LPIPS(net="alex").to(torch.device("cuda")) 38 | return _get_lpips(recon * 2 - 1, gt * 2 - 1).item() 39 | 40 | 41 | ############################################################################### 42 | ## CLIP 43 | ############################################################################### 44 | 45 | 46 | def load_clip_model(model_name="ViT-B/32", device=None): 47 | import clip 48 | if device is None: 49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 50 | model, preprocess = clip.load(model_name, device=device) 51 | return model, preprocess, device 52 | 53 | 54 | def preprocess_image(image, preprocess): 55 | return preprocess(image).unsqueeze(0) 56 | 57 | 58 | @torch.no_grad() 59 | def clip_score(image_a, image_b, model=None, preprocess=None, device=None): 60 | if model is None or preprocess is None or device is None: 61 | model, preprocess, device = load_clip_model() 62 | 63 | # Preprocess images 64 | image_a_preprocessed = preprocess_image(image_a, preprocess).to(device) 65 | image_b_preprocessed = preprocess_image(image_b, preprocess).to(device) 66 | 67 | # Encode images 68 | image_a_features = model.encode_image(image_a_preprocessed) 69 | image_b_features = model.encode_image(image_b_preprocessed) 70 | 71 | # Normalize features 72 | image_a_features = image_a_features / image_a_features.norm(dim=1, keepdim=True) 73 | image_b_features = image_b_features / image_b_features.norm(dim=1, keepdim=True) 74 | 75 | # Calculate CLIP score 76 | logit_scale = model.logit_scale.exp() 77 | score = logit_scale * (image_a_features * image_b_features).sum() 78 | 79 | return score.item() 80 | 81 | 82 | model, preprocess, device = None, None, None 83 | 84 | 85 | def get_clip_score(recon, gt): 86 | global model, preprocess, device 87 | if model is None: 88 | model, preprocess, device = load_clip_model() 89 | return clip_score( 90 | image_utils.torch_to_pil_img(recon), 91 | 
image_utils.torch_to_pil_img(gt), 92 | model, 93 | preprocess, 94 | device, 95 | ) 96 | -------------------------------------------------------------------------------- /lib/diffc/rcc/gaussian_channel_simulator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.diffc.rcc.chunk_coding import ( 3 | get_chunk_sizes, 4 | chunk_and_encode, 5 | decode_from_chunks, 6 | ) 7 | import numpy as np 8 | from zipf_encoding import encode_zipf, decode_zipf 9 | 10 | 11 | class GaussianChannelSimulator: 12 | def __init__(self, max_chunk_size, chunk_padding): 13 | self.max_chunk_size = max_chunk_size 14 | self.chunk_padding = chunk_padding 15 | 16 | def encode(self, mu, manual_dkl=None, seed=0): 17 | """Simulates a noisy channel with identity covariance and mean mu.""" 18 | dkl = manual_dkl 19 | if dkl is None: 20 | dkl = 0.5 * float((mu.astype(np.float32) ** 2).sum() / np.log(2)) 21 | 22 | chunk_sizes = get_chunk_sizes(dkl, self.max_chunk_size, self.chunk_padding) 23 | chunk_seeds, sample = chunk_and_encode( 24 | mu, chunk_sizes=chunk_sizes, shared_seed=seed 25 | ) 26 | 27 | return sample, chunk_seeds, dkl 28 | 29 | def decode(self, chunk_seeds, dim, dkl, seed=0): 30 | chunk_sizes = get_chunk_sizes(dkl, self.max_chunk_size, self.chunk_padding) 31 | return decode_from_chunks(dim, chunk_seeds, chunk_sizes, seed) 32 | 33 | def compress_chunk_seeds(self, chunk_seeds_per_step, dkl_per_step): 34 | zipf_s_vals = [] 35 | zipf_n_vals = [] 36 | seeds = [] 37 | 38 | for chunk_seeds, dkl in zip(chunk_seeds_per_step, dkl_per_step): 39 | chunk_sizes = get_chunk_sizes(dkl, self.max_chunk_size, self.chunk_padding) 40 | chunk_size_sum = sum(chunk_sizes) 41 | for chunk_seed, chunk_size in zip(chunk_seeds, chunk_sizes): 42 | zipf_n_vals.append(2 ** chunk_size) 43 | 44 | chunk_dkl = dkl * chunk_size / chunk_size_sum 45 | s = 1 + 1 / (chunk_dkl + np.exp(-1) * np.log(np.e + 1)) 46 | zipf_s_vals.append(s) 47 | seeds.append(chunk_seed) 48 | 49 | return encode_zipf(zipf_s_vals, zipf_n_vals, seeds) 50 | 51 | def decompress_chunk_seeds(self, encoded_bytes, dkl_per_step): 52 | zipf_s_vals = [] 53 | zipf_n_vals = [] 54 | 55 | for dkl in dkl_per_step: 56 | chunk_sizes = get_chunk_sizes(dkl, self.max_chunk_size, self.chunk_padding) 57 | chunk_size_sum = sum(chunk_sizes) 58 | for chunk_size in chunk_sizes: 59 | zipf_n_vals.append(int(2 ** chunk_size)) 60 | 61 | chunk_dkl = dkl * chunk_size / chunk_size_sum 62 | s = 1 + 1 / (chunk_dkl + np.exp(-1) * np.log(np.e + 1)) 63 | zipf_s_vals.append(s) 64 | 65 | flattened_seeds = decode_zipf(zipf_s_vals, zipf_n_vals, encoded_bytes) 66 | chunk_seeds_per_step = [] 67 | index = 0 68 | for dkl in dkl_per_step: 69 | chunk_sizes = get_chunk_sizes(dkl, self.max_chunk_size, self.chunk_padding) 70 | step_seeds = [] 71 | for chunk_size in chunk_sizes: 72 | step_seeds.append(flattened_seeds[index]) 73 | index += 1 74 | chunk_seeds_per_step.append(step_seeds) 75 | return chunk_seeds_per_step 76 | 77 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Lossy Compression with Pretrained Diffusion Models 2 | 3 | Official implementation of our ICLR 2025 paper [Lossy Compression with Pretrained Diffusion Models](https://arxiv.org/abs/2501.09815) by Jeremy Vonderfecht and Feng Liu. See our [project page](https://jeremyiv.github.io/diffc-project-page/) for an interactive demo of results.
4 | 5 | ## Abstract 6 | 7 | We present a lossy compression method that can leverage state-of-the-art diffusion models for entropy coding. Our method works _zero-shot_, requiring no additional training of the diffusion model or any ancillary networks. We apply the DiffC algorithm[^1] to 8 | [Stable Diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-1) 1.5, 2.1, XL, and [Flux-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev). 9 | We demonstrate that our method is competitive with other state-of-the-art generative compression methods at low to ultra-low bitrates. 10 | 11 | ## Results 12 | 13 | We compare our method (DiffC) against [PerCo](https://github.com/Nikolai10/PerCo), [DiffEIC](https://github.com/huai-chang/DiffEIC), [HiFiC](https://github.com/Justin-Tan/high-fidelity-generative-compression), and [MS-ILLM](https://github.com/facebookresearch/NeuralCompression/tree/main/projects/illm). 14 | 15 | ![Visual Comparison](figures/visual-comparison.png) 16 | 17 | In the following rate-distortion curves, SD1.5, SD2.1, SDXL, and Flux represent the DiffC algorithm with those respective diffusion models. The dashed horizontal 'VAE' lines represent the best achievable metrics given the fidelity of the model's variational autoencoder. 18 | 19 | ![Kodak RD curves](figures/kodak-rd-curves-Qalign.png) 20 | ![Div2k RD curves](figures/div2k-1024-rd-curves-Qalign.png) 21 | 22 | ## Setup 23 | 24 | ``` 25 | git clone https://github.com/JeremyIV/diffc.git 26 | cd diffc 27 | conda env create -f environment.yml 28 | conda activate diffc 29 | ``` 30 | 31 | ## Usage 32 | 33 | ``` 34 | python evaluate.py --config configs/SD-1.5-base.yaml --image_dir data/kodak --output_dir results/SD-1.5-base/kodak 35 | ``` 36 | 37 | To save the compressed representation of an image as a `.diffc` file, use 38 | 39 | ``` 40 | python compress.py --config configs/SD-1.5-base.yaml --image_dir data/kodak --output_dir results/SD-1.5-base/kodak/compressed --recon_timestep 200 41 | ``` 42 | 43 | To reconstruct one or more images from their compressed representations, use 44 | 45 | ``` 46 | python decompress.py --config configs/SD-1.5-base.yaml --input_dir results/SD-1.5-base/kodak/compressed --output_dir results/SD-1.5-base/kodak/reconstructions 47 | ``` 48 | 49 | Note that currently, `compress.py` and `decompress.py` only work with `SD-1.5-base.yaml`. To make them work with the other configs, you would need to specify `manual_dkl_per_step` in the config file; a hypothetical sketch of such an entry is given after the footnotes below. 50 | 51 | ## Citation 52 | 53 | ```bibtex 54 | @inproceedings{ 55 | vonderfecht2025lossy, 56 | title={Lossy Compression with Pretrained Diffusion Models}, 57 | author={Jeremy Vonderfecht and Feng Liu}, 58 | booktitle={The Thirteenth International Conference on Learning Representations}, 59 | year={2025}, 60 | url={https://openreview.net/forum?id=raUnLe0Z04} 61 | } 62 | ``` 63 | 64 | ## Acknowledgements 65 | 66 | Thanks to https://github.com/danieleades/arithmetic-coding for the entropy coding library. 67 | 68 | [^1]: Theis, L., Salimans, T., Hoffman, M. D., & Mentzer, F. (2022). [Lossy compression with gaussian diffusion](https://arxiv.org/abs/2206.08889). arXiv preprint arXiv:2206.08889. 69 | [^2]: Ho, J., Jain, A., & Abbeel, P. (2020). [Denoising diffusion probabilistic models](https://arxiv.org/abs/2006.11239). Advances in Neural Information Processing Systems, 33, 6840-6851.
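As a rough sketch of what such a config entry could look like (the key name `manual_dkl_per_step` comes from the encoder's API; the values below are purely hypothetical placeholders, not tuned settings; there must be one KL budget, in bits, per step of the timestep schedule): ```yaml # Hypothetical illustration only; real values depend on the diffusion model and the timestep schedule. manual_dkl_per_step: [5000.0, 4500.0, 4000.0, 3500.0, 3000.0] ```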
70 | -------------------------------------------------------------------------------- /notebooks/blip-prompts.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "id": "297528c9", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from transformers import Blip2Processor, Blip2ForConditionalGeneration\n", 11 | "from PIL import Image\n", 12 | "import torch\n", 13 | "\n", 14 | "def get_blip2_caption(image: Image.Image, model_name: str = \"Salesforce/blip2-opt-2.7b-coco\", max_length: int = 75) -> str:\n", 15 | " \"\"\"\n", 16 | " Generate a caption for an image using BLIP-2.\n", 17 | " \n", 18 | " Args:\n", 19 | " image: PIL Image in RGB format\n", 20 | " model_name: BLIP-2 model to use\n", 21 | " max_length: Maximum length of generated caption in tokens\n", 22 | " \n", 23 | " Returns:\n", 24 | " str: Generated caption for the image\n", 25 | " \"\"\"\n", 26 | " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", 27 | " \n", 28 | " # Initialize model and processor\n", 29 | " processor = Blip2Processor.from_pretrained(model_name)\n", 30 | " model = Blip2ForConditionalGeneration.from_pretrained(model_name).to(device)\n", 31 | " \n", 32 | " # Prepare image and generate caption\n", 33 | " inputs = processor(images=image, return_tensors=\"pt\").to(device)\n", 34 | " generated_ids = model.generate(**inputs, max_length=max_length)\n", 35 | " caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()\n", 36 | " \n", 37 | " return caption" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 4, 43 | "id": "5f3b958e", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "data": { 48 | "application/vnd.jupyter.widget-view+json": { 49 | "model_id": "3fdbcb2353be4e5192455c5f3de2fd3e", 50 | "version_major": 2, 51 | "version_minor": 0 52 | }, 53 | "text/plain": [ 54 | "Loading checkpoint shards: 0%| | 0/2 [00:00 -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/arithmetic-coding-core/src/model.rs: -------------------------------------------------------------------------------- 36 | /// fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<Self::B>, !> { 37 | /// Ok(match symbol { 38 | /// None => 0..1, 39 | /// Some(&Symbol::A) => 1..2, 40 | /// Some(&Symbol::B) => 2..3, 41 | /// Some(&Symbol::C) => 3..4, 42 | /// }) 43 | /// } 44 | /// 45 | /// fn symbol(&self, value: Self::B) -> Option<Self::Symbol> { 46 | /// match value { 47 | /// 0..1 => None, 48 | /// 1..2 => Some(Symbol::A), 49 | /// 2..3 => Some(Symbol::B), 50 | /// 3..4 => Some(Symbol::C), 51 | /// _ => unreachable!(), 52 | /// } 53 | /// } 54 | /// 55 | /// fn max_denominator(&self) -> u32 { 56 | /// 4 57 | /// } 58 | /// } 59 | /// ``` 60 | pub trait Model { 61 | /// The type of symbol this [`Model`] describes 62 | type Symbol; 63 | 64 | /// Invalid symbol error 65 | type ValueError: Error; 66 | 67 | /// The internal representation to use for storing integers 68 | type B: BitStore; 69 | 70 | /// Given a symbol, return an interval representing the probability of that 71 | /// symbol occurring. 72 | /// 73 | /// This is given as a range, over the denominator given by 74 | /// [`Model::denominator`]. This range should in general include `EOF`, 75 | /// which is denoted by `None`. 76 | /// 77 | /// For example, from the set {heads, tails}, the interval representing 78 | /// heads could be `0..1`, and tails would be `1..2`, and `EOF` could be 79 | /// `2..3` (with a denominator of `3`).
80 | /// 81 | /// This is the inverse of the [`Model::symbol`] method 82 | /// 83 | /// # Errors 84 | /// 85 | /// This returns a custom error if the given symbol is not valid 86 | fn probability( 87 | &self, 88 | symbol: Option<&Self::Symbol>, 89 | ) -> Result<Range<Self::B>, Self::ValueError>; 90 | 91 | /// The denominator for probability ranges. See [`Model::probability`]. 92 | /// 93 | /// By default this method simply returns the [`Model::max_denominator`], 94 | /// which is suitable for non-adaptive models. 95 | /// 96 | /// In adaptive models this value may change, however it should never exceed 97 | /// [`Model::max_denominator`], or it becomes possible for the 98 | /// [`Encoder`](crate::Encoder) and [`Decoder`](crate::Decoder) to panic due 99 | /// to overflow or underflow. 100 | fn denominator(&self) -> Self::B { 101 | self.max_denominator() 102 | } 103 | 104 | /// The maximum denominator used for probability ranges. See 105 | /// [`Model::probability`]. 106 | /// 107 | /// This value is used to calculate an appropriate precision for the 108 | /// encoding, therefore this value must not change, and 109 | /// [`Model::denominator`] must never exceed it. 110 | fn max_denominator(&self) -> Self::B; 111 | 112 | /// Given a value, return the symbol whose probability range it falls in. 113 | /// 114 | /// `None` indicates `EOF` 115 | /// 116 | /// This is the inverse of the [`Model::probability`] method 117 | fn symbol(&self, value: Self::B) -> Option<Self::Symbol>; 118 | 119 | /// Update the current state of the model with the latest symbol. 120 | /// 121 | /// This method only needs to be implemented for 'adaptive' models. It's a 122 | /// no-op by default. 123 | fn update(&mut self, _symbol: Option<&Self::Symbol>) {} 124 | } 125 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/arithmetic-coding-core/src/model/one_shot.rs: -------------------------------------------------------------------------------- 1 | //! Helper trait for creating Models which only accept a single symbol 2 | 3 | use std::ops::Range; 4 | 5 | pub use crate::fixed_length::Wrapper; 6 | use crate::{fixed_length, BitStore}; 7 | 8 | /// A [`Model`] is used to calculate the probability of a given symbol occurring 9 | /// in a sequence. The [`Model`] is used both for encoding and decoding. A 10 | /// 'fixed-length' model always expects an exact number of symbols, and so does 11 | /// not need to encode an EOF symbol. 12 | /// 13 | /// A fixed length model can be converted into a regular model using the 14 | /// convenience [`Wrapper`] type. 15 | /// 16 | /// The more accurately a [`Model`] is able to predict the next symbol, the 17 | /// greater the compression ratio will be.
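/// Internally, a one-shot model is treated as a fixed-length model whose length is exactly 1: see the blanket [`fixed_length::Model`] impl at the bottom of this file, which forwards `probability` and `symbol` and reports a `length()` of `1`.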
18 | /// 19 | /// # Example 20 | /// 21 | /// ``` 22 | /// #![feature(exclusive_range_pattern)] 23 | /// #![feature(never_type)] 24 | /// # use std::ops::Range; 25 | /// # 26 | /// # use arithmetic_coding_core::one_shot; 27 | /// 28 | /// pub enum Symbol { 29 | /// A, 30 | /// B, 31 | /// C, 32 | /// } 33 | /// 34 | /// pub struct MyModel; 35 | /// 36 | /// impl one_shot::Model for MyModel { 37 | /// type Symbol = Symbol; 38 | /// type ValueError = !; 39 | /// type B = u32; 40 | /// fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, !> { 41 | /// Ok(match symbol { 42 | /// Symbol::A => 0..1, 43 | /// Symbol::B => 1..2, 44 | /// Symbol::C => 2..3, 45 | /// }) 46 | /// } 47 | /// 48 | /// fn symbol(&self, value: Self::B) -> Self::Symbol { 49 | /// match value { 50 | /// 0..1 => Symbol::A, 51 | /// 1..2 => Symbol::B, 52 | /// 2..3 => Symbol::C, 53 | /// _ => unreachable!(), 54 | /// } 55 | /// } 56 | /// 57 | /// fn max_denominator(&self) -> u32 { 58 | /// 3 59 | /// } 60 | /// } 61 | /// ``` 62 | pub trait Model { 63 | /// The type of symbol this [`Model`] describes 64 | type Symbol; 65 | 66 | /// Invalid symbol error 67 | type ValueError: std::error::Error; 68 | 69 | /// The internal representation to use for storing integers 70 | type B: BitStore; 71 | 72 | /// Given a symbol, return an interval representing the probability of that 73 | /// symbol occurring. 74 | /// 75 | /// This is given as a range, over the denominator given by 76 | /// [`Model::max_denominator`]. Since a one-shot model always encodes 77 | /// exactly one symbol, no `EOF` entry is required. 78 | /// 79 | /// For example, from the set {heads, tails}, the interval representing 80 | /// heads could be `0..1`, and tails would be `1..2` (with a denominator 81 | /// of `2`). 82 | /// 83 | /// This is the inverse of the [`Model::symbol`] method 84 | /// 85 | /// # Errors 86 | /// 87 | /// This returns a custom error if the given symbol is not valid 88 | fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError>; 89 | 90 | /// The maximum denominator used for probability ranges. See 91 | /// [`Model::probability`]. 92 | /// 93 | /// This value is used to calculate an appropriate precision for the 94 | /// encoding, therefore this value must not change, and 95 | /// [`Model::denominator`] must never exceed it. 96 | fn max_denominator(&self) -> Self::B; 97 | 98 | /// Given a value, return the symbol whose probability range it falls in.
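/// For example, with `MyModel` above, a value of `0` falls in the range `0..1`, and so decodes to `Symbol::A`.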
99 | /// 100 | /// Every in-range value maps to a symbol; a one-shot model has no `EOF`. 101 | /// 102 | /// This is the inverse of the [`Model::probability`] method 103 | fn symbol(&self, value: Self::B) -> Self::Symbol; 104 | } 105 | 106 | impl<T> fixed_length::Model for T 107 | where 108 | T: Model, 109 | { 110 | type B = T::B; 111 | type Symbol = T::Symbol; 112 | type ValueError = T::ValueError; 113 | 114 | fn probability(&self, symbol: &Self::Symbol) -> Result<Range<Self::B>, Self::ValueError> { 115 | Model::probability(self, symbol) 116 | } 117 | 118 | fn max_denominator(&self) -> Self::B { 119 | self.max_denominator() 120 | } 121 | 122 | fn symbol(&self, value: Self::B) -> Self::Symbol { 123 | Model::symbol(self, value) 124 | } 125 | 126 | fn length(&self) -> usize { 127 | 1 128 | } 129 | 130 | fn denominator(&self) -> Self::B { 131 | self.max_denominator() 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /lib/diffc/rcc/arithmetic-coding/python-bindings/src/zipf.rs: -------------------------------------------------------------------------------- 1 | //#![feature(exclusive_range_pattern)] 2 | 3 | use std::ops::Range; 4 | 5 | use arithmetic_coding::Model; 6 | 7 | //mod common; 8 | 9 | #[derive(Clone)] 10 | pub struct ZipfModel { 11 | s_values: Vec<f64>, 12 | n_values: Vec<u32>, 13 | current_index: usize, 14 | cmf: Vec<u64>, 15 | exhausted: bool, 16 | } 17 | 18 | 19 | impl ZipfModel { 20 | pub fn new(s_values: Vec<f64>, n_values: Vec<u32>) -> Self { 21 | assert_eq!(s_values.len(), n_values.len(), "s_values and n_values must have the same length"); 22 | assert!(!s_values.is_empty(), "At least one (s, n) pair is required"); 23 | 24 | let mut model = ZipfModel { 25 | s_values, 26 | n_values, 27 | current_index: 0, 28 | cmf: Vec::new(), 29 | exhausted: false, 30 | }; 31 | model.update_cmf(); 32 | model 33 | } 34 | 35 | fn update_cmf(&mut self) { 36 | let s = self.s_values[self.current_index]; 37 | let n = self.n_values[self.current_index]; 38 | 39 | let max_denom = Self::max_denominator(); 40 | self.cmf = vec![0; n as usize + 1]; 41 | 42 | let mut sum = 0.0; 43 | let harmonic_n: f64 = (1..=n).map(|k| 1.0 / (k as f64).powf(s)).sum(); 44 | 45 | for k in 1..=n { 46 | sum += 1.0 / (k as f64).powf(s) / harmonic_n; 47 | self.cmf[k as usize] = (sum * max_denom as f64).min(max_denom as f64) as u64; 48 | } 49 | 50 | // Ensure the last value is exactly max_denominator 51 | *self.cmf.last_mut().unwrap() = max_denom; 52 | } 53 | 54 | fn max_denominator() -> u64 { 55 | 1 << 24 56 | } 57 | 58 | fn is_significantly_different(a: f64, b: f64) -> bool { 59 | (a - b).abs() > 1e-10 60 | } 61 | 62 | fn should_update_cmf(&self) -> bool { 63 | let current_s = self.s_values[self.current_index]; 64 | let previous_s = self.s_values[self.current_index - 1]; 65 | let current_n = self.n_values[self.current_index]; 66 | let previous_n = self.n_values[self.current_index - 1]; 67 | 68 | Self::is_significantly_different(current_s, previous_s) || current_n != previous_n 69 | } 70 | } 71 | 72 | #[derive(Debug, thiserror::Error)] 73 | #[error("invalid symbol: {0}")] 74 | pub struct Error(pub u32); 75 | 76 | impl Model for ZipfModel { 77 | type Symbol = u32; 78 | type ValueError = Error; 79 | type B = u64; 80 | 81 | fn probability(&self, symbol: Option<&Self::Symbol>) -> Result<Range<Self::B>, Error> { 82 | if self.exhausted { 83 | match symbol { 84 | None => Ok(0..Self::max_denominator()), 85 | Some(&s) => Err(Error(s)), 86 | } 87 | } else { 88 | match symbol { 89 | None => Err(Error(self.n_values[self.current_index])), // Error for None when not exhausted 90 | Some(&s) if s
< self.n_values[self.current_index] => { 91 | Ok(self.cmf[s as usize]..self.cmf[s as usize + 1]) 92 | }, 93 | Some(&s) => Err(Error(s)), 94 | } 95 | } 96 | } 97 | 98 | fn symbol(&self, value: Self::B) -> Option<Self::Symbol> { 99 | if self.exhausted { 100 | return None; 101 | } 102 | 103 | if value >= Self::max_denominator() { 104 | return None; 105 | } 106 | 107 | match self.cmf.binary_search(&value) { 108 | Ok(exact_index) => Some(exact_index as u32), // value sits exactly at the start of symbol exact_index's range 109 | Err(insertion_index) => Some(insertion_index as u32 - 1), 110 | } 111 | } 112 | 113 | fn max_denominator(&self) -> Self::B { 114 | Self::max_denominator() 115 | } 116 | 117 | fn update(&mut self, symbol: Option<&Self::Symbol>) { 118 | if self.exhausted { 119 | return; 120 | } 121 | 122 | if symbol.is_none() { 123 | self.exhausted = true; 124 | return; 125 | } 126 | 127 | self.current_index += 1; 128 | if self.current_index >= self.s_values.len() { 129 | self.exhausted = true; 130 | } else if self.should_update_cmf() { 131 | self.update_cmf(); 132 | } 133 | } 134 | } 135 | 136 | 137 | // fn main() { 138 | // let s_values = vec![1.0, 1.1, 1.2, 1.0, 1.0, 1.1, 1.2, 1.0, 1.1, 1.2, 1.0, 1.0, 1.1, 1.2]; 139 | // let n_values = vec![3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]; 140 | // let model = ZipfModel::new(s_values, n_values); 141 | 142 | // common::round_trip(model, vec![2, 1, 1, 2, 2, 0, 1, 2, 1, 1, 2, 2, 0, 1]); 143 | // } 144 | -------------------------------------------------------------------------------- /lib/diffc/encode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.diffc.utils.p import P 3 | from lib.diffc.utils.q import Q 4 | from tqdm import tqdm 5 | 6 | 7 | @torch.no_grad() 8 | def encode( 9 | target_latent, 10 | timestep_schedule, 11 | noise_prediction_model, 12 | gaussian_channel_simulator, 13 | manual_dkl_per_step=None, 14 | recon_timesteps=[], 15 | seed=0, 16 | ): 17 | """Creates a compressed representation of an image using a diffusion model. 18 | 19 | Args: 20 | target_latent: Latent representation of the image to encode, as produced by the 21 | diffusion model's VAE encoder. 22 | timestep_schedule: List of diffusion timesteps in decreasing order. Each 23 | timestep's SNR is looked up via noise_prediction_model.get_timestep_snr, so 24 | the schedule implicitly defines a decreasing sequence of SNRs. The last 25 | element must be > 0. Ending at timestep 0 (lossless compression of the 26 | latent) is not currently supported (and probably 27 | not desirable). 28 | noise_prediction_model: Model which takes in a noisy latent and its timestep, and 29 | returns a prediction of the latent's noise component. 30 | gaussian_channel_simulator: Used for gaussian channel simulation. 31 | manual_dkl_per_step: Used to manually hard-code the dkl per step. Otherwise we'd 32 | need to send it as side-information. TODO: fancier entropy models of dkl per 33 | step? 34 | recon_timesteps: List of timesteps in decreasing order. When used, saves the noisy 35 | latents from the encoding process at each timestep. 36 | seed: 37 | random seed for the compression process. 38 | 39 | Returns: 40 | tuple: 41 | - chunk_seeds_per_step (List[List[int]]): One list of ints per step. This is 42 | the compressed representation of the image, although it still needs to be 43 | entropy coded. Fed back into the gaussian channel simulator for decoding.
44 | - dkl_per_step (List[float]): This is also fed back into the gaussian 45 | channel simulator to reconstruct the denoising process. 46 | - noisy_recons: Noisy reconstructions of the target image generated during 47 | the encoding process. These will be the same noisy reconstructions 48 | generated during decoding. For faster evaluation, we can skip decoding and 49 | just use these recons. 50 | - noisy_recon_step_indices (List[int]): List which is parallel to 51 | noisy_recons, and reports the step index for each recon. 52 | """ 53 | chunk_seeds_per_step = [] 54 | dkl_per_step = [] 55 | noisy_recons = [] 56 | noisy_recon_step_indices = [] 57 | recon_timesteps = recon_timesteps.copy() 58 | 59 | torch.manual_seed(seed) 60 | noisy_latent = torch.randn( 61 | target_latent.shape, device=target_latent.device, dtype=target_latent.dtype 62 | ) 63 | 64 | current_timestep = 1000 65 | current_snr = noise_prediction_model.get_timestep_snr(current_timestep) 66 | 67 | for step_index, prev_timestep in tqdm( 68 | enumerate(timestep_schedule), total=len(timestep_schedule) 69 | ): # "previous" as in less noisy: a smaller timestep, and hence a higher SNR, than the current one 70 | noise_prediction = noise_prediction_model.predict_noise( 71 | noisy_latent, current_timestep 72 | ) 73 | prev_snr = noise_prediction_model.get_timestep_snr(prev_timestep) 74 | p_mu, std = P(noisy_latent, noise_prediction, current_snr, prev_snr) 75 | q_mu = Q(noisy_latent, target_latent, current_snr, prev_snr) 76 | q_mu_flat_normed = ((q_mu - p_mu) / std).flatten().detach().cpu().numpy() 77 | 78 | manual_dkl = ( 79 | None if manual_dkl_per_step is None else manual_dkl_per_step[step_index] 80 | ) 81 | 82 | sample, chunk_seeds, dkl = gaussian_channel_simulator.encode( 83 | q_mu_flat_normed, manual_dkl=manual_dkl, seed=step_index 84 | ) 85 | chunk_seeds_per_step.append(chunk_seeds) 86 | dkl_per_step.append(dkl) 87 | sample = torch.tensor(sample) 88 | reshaped_sample = ( 89 | sample.reshape(noisy_latent.shape) 90 | .to(noisy_latent.device) 91 | .to(noisy_latent.dtype) 92 | ) 93 | noisy_latent = reshaped_sample * std + p_mu 94 | current_timestep = prev_timestep 95 | current_snr = prev_snr 96 | 97 | ## Optionally, save the current reconstruction 98 | save_current_latent = False 99 | while len(recon_timesteps) > 0 and current_timestep <= recon_timesteps[0]: 100 | save_current_latent = True 101 | recon_timesteps = recon_timesteps[1:] 102 | 103 | if save_current_latent: 104 | noisy_recons.append(noisy_latent) 105 | noisy_recon_step_indices.append(step_index) 106 | 107 | return chunk_seeds_per_step, dkl_per_step, noisy_recons, noisy_recon_step_indices 108 | -------------------------------------------------------------------------------- /decompress.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import yaml 4 | from easydict import EasyDict as edict 5 | import zlib 6 | import struct 7 | 8 | from lib import image_utils 9 | from lib.diffc.denoise import denoise 10 | from lib.diffc.decode import decode 11 | from lib.diffc.rcc.gaussian_channel_simulator import GaussianChannelSimulator 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser( 15 | description="Decompress DiffC-compressed images" 16 | ) 17 | parser.add_argument( 18 | "--config", 19 | help="Path to the compression config file", 20 | required=True 21 | ) 22 | parser.add_argument( 23 | "--input_path", 24 | default=None, 25 | help="Path to a single .diffc file to decompress" 26 | ) 27 | parser.add_argument( 28 |
"--input_dir", 29 | default=None, 30 | help="Path to a directory containing .diffc files to decompress" 31 | ) 32 | parser.add_argument( 33 | "--output_dir", 34 | required=True, 35 | help="Directory to output the decompressed images to" 36 | ) 37 | return parser.parse_args() 38 | 39 | def get_noise_prediction_model(model_name, config): 40 | if model_name == "SD1.5": 41 | from lib.models.SD15 import SD15Model 42 | return SD15Model() 43 | elif model_name == "SD2.1": 44 | from lib.models.SD21 import SD21Model 45 | return SD21Model() 46 | elif model_name == "SDXL": 47 | from lib.models.SDXL import SDXLModel 48 | use_refiner = config.get("use_refiner", False) 49 | return SDXLModel(use_refiner=use_refiner) 50 | elif model_name == 'Flux': 51 | from lib.models.Flux import FluxModel 52 | return FluxModel() 53 | else: 54 | raise ValueError(f"Unrecognized model: {model_name}") 55 | 56 | def read_diffc_file(file_path): 57 | with open(file_path, 'rb') as f: 58 | # Read caption length (4 bytes) 59 | caption_length = struct.unpack('