├── .github └── workflows │ └── deploy.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .style.yapf ├── ACKNOWLEDGEMENT.md ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── diffusion ├── __init__.py ├── base.py ├── data.py ├── distribution.py ├── guidance.py ├── loss.py ├── net.py ├── noise.py ├── schedule.py ├── time.py └── utils │ └── nn │ ├── __init__.py │ └── functional.py ├── docs ├── .gitignore ├── astro.config.mjs ├── env.d.ts ├── package-lock.json ├── package.json ├── public │ ├── fonts │ │ ├── FiraCode │ │ │ ├── FiraCode-Bold.ttf │ │ │ ├── FiraCode-Light.ttf │ │ │ ├── FiraCode-Medium.ttf │ │ │ ├── FiraCode-Regular.ttf │ │ │ └── FiraCode-SemiBold.ttf │ │ └── Inter │ │ │ ├── Inter-Black.ttf │ │ │ ├── Inter-Bold.ttf │ │ │ ├── Inter-ExtraBold.ttf │ │ │ ├── Inter-ExtraLight.ttf │ │ │ ├── Inter-Light.ttf │ │ │ ├── Inter-Medium.ttf │ │ │ ├── Inter-Regular.ttf │ │ │ ├── Inter-SemiBold.ttf │ │ │ └── Inter-Thin.ttf │ ├── images │ │ ├── guides │ │ │ └── getting-started │ │ │ │ ├── conditional.png │ │ │ │ ├── unconditional-cosine.png │ │ │ │ └── unconditional-linear.png │ │ ├── logo.ico │ │ ├── logo.png │ │ ├── logo2.png │ │ ├── menu.svg │ │ └── modules │ │ │ ├── noise-schedule │ │ │ ├── constant.png │ │ │ ├── cosine.png │ │ │ ├── linear.png │ │ │ └── sqrt.png │ │ │ └── noise-type │ │ │ ├── absorbing.png │ │ │ ├── gaussian.png │ │ │ └── uniform.png │ └── robots.txt ├── src │ ├── components │ │ ├── Article.astro │ │ ├── Header.astro │ │ ├── Index.astro │ │ ├── Sidebar.astro │ │ └── markdown │ │ │ ├── Bullet.astro │ │ │ ├── Code.astro │ │ │ ├── Image.astro │ │ │ ├── Link.astro │ │ │ ├── Note.astro │ │ │ ├── Paragraph.astro │ │ │ ├── Subtitle.astro │ │ │ └── Title.astro │ ├── env.d.ts │ ├── layouts │ │ └── Layout.astro │ ├── pages │ │ ├── _drafts │ │ │ └── model.mdx │ │ ├── guides │ │ │ ├── custom-modules.mdx │ │ │ ├── getting-started.mdx │ │ │ ├── image-generation.mdx │ │ │ └── text-generation.mdx │ │ └── modules │ │ │ ├── data-transform.mdx │ │ │ ├── denoising-network.mdx │ │ │ ├── diffusion-model.mdx │ │ │ ├── guidance.mdx │ │ │ ├── loss-function.mdx │ │ │ ├── noise-schedule.mdx │ │ │ ├── noise-type.mdx │ │ │ └── probability-distribution.mdx │ ├── plugins │ │ └── remark-layout.mjs │ └── styles │ │ └── fonts.css ├── tailwind.config.cjs └── tsconfig.json ├── examples ├── conditional-diffusion.py ├── data │ └── representative │ │ ├── in │ │ ├── afhq │ │ │ ├── flickr_dog_000083.jpg │ │ │ ├── flickr_dog_001159.jpg │ │ │ ├── pixabay_dog_000802.jpg │ │ │ ├── pixabay_dog_003974.jpg │ │ │ └── pixabay_dog_004034.jpg │ │ ├── e2e │ │ │ ├── 0.txt │ │ │ ├── 1.txt │ │ │ ├── 2.txt │ │ │ ├── 3.txt │ │ │ └── 4.txt │ │ └── mnist │ │ │ ├── 1.png │ │ │ ├── 2.png │ │ │ ├── 3.png │ │ │ ├── 6.png │ │ │ └── 7.png │ │ └── out │ │ ├── conditional-diffusion │ │ ├── 1-7.31e-02.png │ │ ├── 2-3.91e-02.png │ │ ├── 20-3.56e-02.png │ │ ├── 4-4.54e-02.png │ │ └── 7-4.03e-02.png │ │ ├── embedding-diffusion │ │ ├── 293 (4.36e-03).txt │ │ ├── 490 (2.41e-02).txt │ │ ├── 5 (1.63e-01).txt │ │ ├── 51 (5.72e-02).txt │ │ └── 960 (1.87e-02).txt │ │ ├── transformer-diffusion │ │ ├── 26-2.42e-02.png │ │ ├── 3650-7.06e-03.png │ │ ├── 4986-7.32e-03.png │ │ ├── 5002-6.01e-03.png │ │ └── 509-1.29e-01.png │ │ └── unconditional-diffusion │ │ ├── 18-3.67e-02.png │ │ ├── 2-5.50e-02.png │ │ ├── 36-3.63e-02.png │ │ ├── 42-3.40e-02.png │ │ └── 51-3.40e-02.png ├── embedding-diffusion.py ├── text-diffusion.py ├── transformer-diffusion.py ├── unconditional-diffusion.py └── utils │ └── __init__.py ├── 
pyproject.toml ├── pyrightconfig.json └── requirements.txt /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | 3 | on: 4 | # Trigger the workflow every time you push to the `main` branch 5 | # Using a different branch name? Replace `main` with your branch’s name 6 | push: 7 | branches: [main] 8 | # Allows you to run this workflow manually from the Actions tab on GitHub. 9 | workflow_dispatch: 10 | 11 | # Allow this job to clone the repo and create a page deployment 12 | permissions: 13 | contents: read 14 | pages: write 15 | id-token: write 16 | 17 | jobs: 18 | build: 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Checkout your repository using git 22 | uses: actions/checkout@v3 23 | - name: Install, build, and upload your site output 24 | uses: withastro/action@v0 25 | with: 26 | path: docs # The root location of your Astro project inside the repository. (optional) 27 | # node-version: 16 # The specific version of Node that should be used to build your site. Defaults to 16. (optional) 28 | # package-manager: yarn # The Node package manager that should be used to install dependencies and build your site. Automatically detected based on your lockfile. (optional) 29 | # resolve-dep-from-path: false # If the dependency file should be resolved from the root location of your Astro project. Defaults to `true`. (optional) 30 | 31 | deploy: 32 | needs: build 33 | runs-on: ubuntu-latest 34 | environment: 35 | name: github-pages 36 | url: ${{ steps.deployment.outputs.page_url }} 37 | steps: 38 | - name: Deploy to GitHub Pages 39 | id: deployment 40 | uses: actions/deploy-pages@v1 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | .vscode 3 | __pycache__ 4 | env 5 | nohup.out 6 | private 7 | examples/data/in 8 | examples/data/out 9 | dist -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: requirements-txt-fixer 6 | - repo: https://github.com/pycqa/isort 7 | rev: 5.12.0 8 | hooks: 9 | - id: isort 10 | args: [-l=88, -m=3, --tc] 11 | - repo: https://github.com/google/yapf 12 | rev: v0.33.0 13 | hooks: 14 | - id: yapf 15 | name: yapf 16 | language: python 17 | entry: yapf 18 | args: [-i, -vv] 19 | types: [python] 20 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = google 3 | column_limit = 88 4 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENT.md: -------------------------------------------------------------------------------- 1 | # Acknowledgement 2 | 3 | This work was partly funded by: project SmartEDU (CENTRO-01-0247-FEDER-072620), co-financed by FEDER, through PT2020, and by the Regional Operational Programme Centro 2020; and national funds 4 | through the FCT – Foundation for Science and Technology, I.P., within the scope of the project CISUC – UID/CEC/00326/2020 and by the European Social Fund, through the Regional Operational 5 | Program Centro 2020. 
6 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | authors: 2 | - family-names: Cabral Pinto 3 | given-names: João 4 | cff-version: 1.2.0 5 | message: "If you use this library, please cite it using these metadata." 6 | title: "Modular Diffusion" 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | jmcabralpinto@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 
65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Modular Diffusion 2 | 3 | Thank you for your interest in contributing to Modular Diffusion! We welcome your support and contributions to help us improve and expand this framework. Whether you want to fix a bug, enhance an existing feature, or introduce a new prebuilt module, your contributions are valuable to us. 
4 | 5 | ## How to Contribute 6 | 7 | ### Reporting Issues 8 | 9 | If you encounter a bug, have a suggestion, or find a typo in the documentation, please feel free to [open an issue](https://github.com/cabralpinto/modular-diffusion/issues) on our GitHub repository. When reporting an issue, please provide as much detail as possible, including: 10 | 11 | - A clear and concise description of the issue or enhancement request. 12 | - Steps to reproduce the issue, if applicable. 13 | - Information about your environment, such as Python version, PyTorch version, and any other relevant details. 14 | 15 | ### Adding New Features and Modules 16 | 17 | If you want to contribute a new feature to Modular Diffusion, we recommend starting by opening an issue to discuss your idea with the maintainers. This way, we can ensure that the module aligns with the project's goals and design. If we give you the green light, you're free to submit a pull request. 18 | 19 | ### Pull Requests 20 | 21 | We encourage you to submit pull requests to address issues or contribute new features. Here's how to do it: 22 | 23 | 1. Fork the [Modular Diffusion GitHub repository](https://github.com/cabralpinto/modular-diffusion) to your own GitHub account. 24 | 25 | 2. Clone your forked repository to your local machine: 26 | 27 | ```bash 28 | git clone https://github.com/your-username/modular-diffusion.git 29 | ``` 30 | 31 | 3. Create a new branch for your changes: 32 | 33 | ```bash 34 | git checkout -b your-feature-name 35 | ``` 36 | 37 | 4. Make your changes, attempting to follow the same code style of the project. 38 | 39 | 5. Commit your changes with clear and descriptive commit messages: 40 | 41 | ```bash 42 | git commit -m "Add new feature" 43 | ``` 44 | 45 | 6. Push your changes to your GitHub repository: 46 | 47 | ```bash 48 | git push origin your-feature-name 49 | ``` 50 | 51 | 7. Open a pull request on the [Modular Diffusion GitHub repository](https://github.com/cabralpinto/modular-diffusion) from your branch to the `main` branch. Please provide a detailed description of your changes in the pull request. 52 | 53 | 8. Your pull request will be reviewed, and any necessary feedback or changes will be discussed. 54 | 55 | 9. Once your pull request is approved, it will be merged into the main codebase, and your contribution will be acknowledged. 56 | 57 | ## License 58 | 59 | By contributing to Modular Diffusion, you agree that your contributions will be licensed under the [MIT License](LICENSE). This ensures that your contributions can be freely used, modified, and distributed by others. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 João Cabral Pinto 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modular Diffusion 2 | 3 | [![PyPI version](https://badge.fury.io/py/modular-diffusion.svg)](https://badge.fury.io/py/modular-diffusion) 4 | [![Documentation](https://img.shields.io/badge/docs-stable-blue.svg)](https://cabralpinto.github.io/modular-diffusion/) 5 | [![MIT license](https://img.shields.io/badge/license-MIT-blue.svg)](https://lbesson.mit-license.org/) 6 | [![Discord](https://dcbadge.vercel.app/api/server/mYJWQATfTV?style=flat&compact=true)](https://discord.gg/mYJWQATfTV) 7 | 8 | Modular Diffusion provides an easy-to-use modular API to design and train custom Diffusion Models with PyTorch. Whether you're an enthusiast exploring Diffusion Models or a hardcore ML researcher, **this framework is for you**. 9 | 10 | ## Features 11 | 12 | - ⚙️ **Highly Modular Design**: Effortlessly swap different components of the diffusion process, including noise type, schedule type, denoising network, and loss function. 13 | - 📚 **Growing Library of Pre-built Modules**: Get started right away with our comprehensive selection of pre-built modules. 14 | - 🔨 **Custom Module Creation Made Easy**: Craft your own original modules by inheriting from a base class and implementing the required methods. 15 | - 🤝 **Integration with PyTorch**: Built on top of PyTorch, Modular Diffusion enables you to develop custom modules using a familiar syntax. 16 | - 🌈 **Broad Range of Applications**: From generating high-quality images to implementing non-autoregressive text synthesis pipelines, the possibilities are endless. 17 | 18 | ## Installation 19 | 20 | Modular Diffusion officially supports Python 3.10+ and is available on PyPI: 21 | 22 | ```bash 23 | pip install modular-diffusion 24 | ``` 25 | 26 | You also need to install the correct [PyTorch distribution](https://pytorch.org/get-started/locally/) for your system. 27 | 28 | > **Note**: Although Modular Diffusion works with later Python versions, we currently recommend using Python 3.10. This is because `torch.compile`, which significantly improves the speed of the models, is not currently available for versions above Python 3.10. 29 | 30 | ## Usage 31 | 32 | With Modular Diffusion, you can build and train a custom Diffusion Model in just a few lines. First, load and normalize your dataset. We are using the dog pictures from [AFHQ](https://paperswithcode.com/dataset/afhq). 33 | 34 | ```python 35 | x, _ = zip(*ImageFolder("afhq", ToTensor())) 36 | x = resize(x, [h, w], antialias=False) 37 | x = torch.stack(x) * 2 - 1 38 | ``` 39 | 40 | Next, build your custom model using either Modular Diffusion's prebuilt modules or [your custom modules](https://cabralpinto.github.io/modular-diffusion/guides/custom-modules/).
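For reference, a custom module only needs to subclass one of the base classes in `diffusion.base` and implement its abstract methods. The sketch below defines a custom noise schedule by subclassing `Schedule` and implementing `compute`, which must return one `alpha_t` value per step; the `Polynomial` name, its `exponent` parameter, and the particular formula are illustrative assumptions, not part of the library.

```python
from dataclasses import dataclass

import torch
from torch import Tensor

from diffusion.base import Schedule


@dataclass
class Polynomial(Schedule):
    # Hypothetical schedule, shown only to illustrate the base-class pattern.
    exponent: float = 2.0

    def compute(self) -> Tensor:
        # Cumulative signal level delta_t, converted into per-step values
        # alpha_t = delta_t / delta_{t-1}, mirroring the structure of the
        # prebuilt Cosine schedule.
        t = torch.arange(self.steps + 2)
        delta = 1 - (t / (self.steps + 2)) ** self.exponent
        return torch.clip(delta[1:] / delta[:-1], 1e-3, 1)
```

Such a class could then be passed as `schedule=Polynomial(steps=1000)` in the model definition below, in place of a prebuilt schedule.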
41 | 42 | ```python 43 | model = diffusion.Model( 44 | data=Identity(x, batch=128, shuffle=True), 45 | schedule=Cosine(steps=1000), 46 | noise=Gaussian(parameter="epsilon", variance="fixed"), 47 | net=UNet(channels=(1, 64, 128, 256)), 48 | loss=Simple(parameter="epsilon"), 49 | ) 50 | ``` 51 | 52 | Now, train and sample from the model. 53 | 54 | ```python 55 | losses = [*model.train(epochs=400)] 56 | z = model.sample(batch=10) 57 | z = z[torch.linspace(0, z.shape[0] - 1, 10).long()] 58 | z = rearrange(z, "t b c h w -> c (b h) (t w)") 59 | save_image((z + 1) / 2, "output.png") 60 | ``` 61 | 62 | Finally, marvel at the results. 63 | 64 | Modular Diffusion teaser  65 | 66 | Check out the [Getting Started Guide](https://cabralpinto.github.io/modular-diffusion/guides/getting-started/) to learn more and find more examples [here](https://github.com/cabralpinto/modular-diffusion/tree/main/examples). 67 | 68 | ## Contributing 69 | 70 | We appreciate your support and welcome your contributions! Please feel free to submit pull requests if you found a bug or typo you want to fix. If you want to contribute a new prebuilt module or feature, please start by opening an issue and discussing it with us. If you don't know where to begin, take a look at the [open issues](https://github.com/cabralpinto/modular-diffusion/issues). Please read our [Contributing Guide](https://github.com/cabralpinto/modular-diffusion/blob/main/CONTRIBUTING.md) for more details. 71 | 72 | ## License 73 | 74 | This project is licensed under the [MIT License](LICENSE). 75 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass, field 3 | from itertools import chain 4 | from pathlib import Path 5 | from typing import Callable, Generic, Iterator, Optional, TypeVar 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import Tensor 10 | from torch.optim import Adam, Optimizer 11 | from tqdm import tqdm 12 | 13 | from . 
import data, guidance, loss, net, noise, schedule, time 14 | from .base import Batch, Data, Distribution, Guidance, Loss, Net, Noise, Schedule, Time 15 | from .time import Discrete 16 | 17 | __all__ = ["data", "loss", "net", "noise", "schedule", "time", "Model"] 18 | 19 | D = TypeVar("D", bound=Distribution, covariant=True) 20 | 21 | 22 | @dataclass 23 | class Model(Generic[D]): 24 | data: Data 25 | schedule: Schedule 26 | noise: Noise[D] 27 | net: Net 28 | loss: Loss[D] 29 | time: Time = field(default_factory=Discrete) 30 | guidance: Optional[Guidance] = None # TODO remove hardcoding 31 | optimizer: Optional[Optimizer | Callable[..., Optimizer]] = None 32 | device: str | torch.device = torch.device("cpu") 33 | compile: bool = True 34 | 35 | @torch.no_grad() 36 | def __post_init__(self): 37 | self.noise.schedule(self.schedule.compute().to(self.device)) 38 | parameters = chain(self.data.parameters(), self.net.parameters()) 39 | if self.optimizer is None: 40 | self.optimizer = Adam(parameters, lr=1e-4) 41 | elif callable(self.optimizer): 42 | self.optimizer = self.optimizer(parameters) 43 | self.net = self.net.to(self.device) 44 | for name, value in vars(self.data).items(): 45 | if isinstance(value, nn.Module): 46 | setattr(self.data, name, value.to(self.device)) 47 | if self.compile and sys.version_info < (3, 11): 48 | self.net = torch.compile(self.net) # type: ignore[union-attr] 49 | if isinstance(self.device, str): 50 | self.device = torch.device(self.device) 51 | 52 | @torch.enable_grad() 53 | def train(self, epochs: int = 1, progress: bool = True) -> Iterator[float]: 54 | self.net.train() 55 | batch = Batch[D](self.device) # type: ignore[union-attr] 56 | for _ in range(epochs): 57 | bar = tqdm(self.data, disable=not progress) 58 | for batch.w, batch.y in self.data: 59 | if isinstance(self.guidance, guidance.ClassifierFree): 60 | i = torch.randperm(batch.y.shape[0]) 61 | batch.y[i[:int(batch.y.shape[0] * self.guidance.dropout)]] = 0 62 | batch.x = self.data.encode(batch.w) 63 | batch.t = self.time.sample(self.schedule.steps, batch.x.shape[0]) 64 | batch.z, batch.epsilon = self.noise.prior(batch.x, batch.t).sample() 65 | batch.hat = self.net(batch.z, batch.y, batch.t) 66 | batch.q = self.noise.posterior(batch.x, batch.z, batch.t) 67 | batch.p = self.noise.approximate(batch.z, batch.t, batch.hat) 68 | batch.l = self.loss.compute(batch) 69 | self.optimizer.zero_grad() # type: ignore[union-attr] 70 | batch.l.backward() 71 | self.optimizer.step() # type: ignore[union-attr] 72 | bar.set_postfix(loss=f"{batch.l.item():.2e}") 73 | bar.update() 74 | bar.close() 75 | yield batch.l.item() 76 | 77 | @torch.no_grad() 78 | def sample( 79 | self, 80 | y: Optional[Tensor] = None, 81 | batch: int = 1, 82 | progress: bool = True, 83 | ) -> Tensor: 84 | self.net.eval() 85 | if y is None: 86 | shape = 1, *(() if self.data.y is None else self.data.y.shape[1:]) 87 | y = torch.zeros(shape, dtype=torch.int, device=self.device) 88 | y = y.repeat_interleave(batch, 0).to(self.device) 89 | pi = self.noise.stationary((y.shape[0], *self.data.shape)) 90 | z = pi.sample()[0].to(self.device) 91 | l = self.data.decode(z)[None] 92 | bar = tqdm(total=self.schedule.steps, disable=not progress) 93 | for t in range(self.schedule.steps, 0, -1): 94 | t = torch.full((batch,), t, device=self.device) 95 | hat = self.net(z, y, t) 96 | if isinstance(self.guidance, guidance.ClassifierFree): 97 | s = self.guidance.strength 98 | hat = (1 + s) * hat - s * self.net(z, torch.zeros_like(y), t) 99 | z, _ = self.noise.approximate(z, t, 
hat).sample() 100 | w = self.data.decode(z) 101 | l = torch.cat((l, w[None]), 0) 102 | bar.update() 103 | bar.close() 104 | return l 105 | 106 | @torch.no_grad() 107 | def load(self, path: Path | str): 108 | state = torch.load(path) 109 | self.net.load_state_dict(state["net"]) 110 | for name, dict in state["data"].items(): 111 | getattr(self.data, name).load_state_dict(dict) 112 | 113 | @torch.no_grad() 114 | def save(self, path: Path | str): 115 | state = { 116 | "net": self.net.state_dict(), 117 | "data": { 118 | name: value.state_dict() 119 | for name, value in vars(self.data).items() 120 | if isinstance(value, nn.Module) 121 | } 122 | } 123 | torch.save(state, path) 124 | -------------------------------------------------------------------------------- /diffusion/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass, field 3 | from itertools import chain 4 | from typing import Any, Callable, Generic, Iterator, Optional, TypeVar 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | from typing_extensions import Self 10 | 11 | from .utils.nn import Sequential 12 | 13 | __all__ = ["Batch", "Data", "Distribution", "Loss", "Net", "Noise", "Schedule", "Time"] 14 | 15 | 16 | class Distribution(ABC): 17 | 18 | @abstractmethod 19 | def sample(self) -> tuple[Tensor, Optional[Tensor]]: 20 | raise NotImplementedError 21 | 22 | @abstractmethod 23 | def nll(self, x: Tensor) -> Tensor: 24 | raise NotImplementedError 25 | 26 | @abstractmethod 27 | def dkl(self, other: Self) -> Tensor: 28 | raise NotImplementedError 29 | 30 | 31 | D = TypeVar("D", bound=Distribution, covariant=True) 32 | 33 | 34 | @dataclass 35 | class Batch(Generic[D]): 36 | device: torch.device 37 | w: Tensor = field(init=False) 38 | x: Tensor = field(init=False) 39 | y: Tensor = field(init=False) 40 | t: Tensor = field(init=False) 41 | epsilon: Optional[Tensor] = field(init=False) 42 | z: Tensor = field(init=False) 43 | hat: Tensor = field(init=False) 44 | q: D = field(init=False) 45 | p: D = field(init=False) 46 | l: Tensor = field(init=False) 47 | 48 | def __setattr__(self, prop: str, val: Any): 49 | if isinstance(val, Tensor): 50 | if hasattr(self, prop): 51 | del self.__dict__[prop] 52 | val = val.to(self.device) 53 | super().__setattr__(prop, val) 54 | 55 | 56 | @dataclass 57 | class Data(ABC): 58 | w: Tensor 59 | y: Optional[Tensor] = None 60 | batch: int = 1 61 | shuffle: bool = False 62 | 63 | @property 64 | def shape(self) -> tuple[int, ...]: 65 | return self.encode(self.w[:1]).shape[1:] 66 | 67 | def parameters(self) -> Iterator[nn.Parameter]: 68 | return chain.from_iterable( 69 | var.parameters() for var in vars(self) if isinstance(var, nn.Module)) 70 | 71 | def __iter__(self) -> Iterator[tuple[Tensor, Tensor]]: 72 | if self.y is None: 73 | self.y = torch.zeros(self.w.shape[0], dtype=torch.int) 74 | if self.shuffle: 75 | index = torch.randperm(self.w.shape[0]) 76 | self.w, self.y = self.w[index], self.y[index] 77 | self.data = zip(self.w.split(self.batch), self.y.split(self.batch)) 78 | return self 79 | 80 | def __next__(self) -> tuple[Tensor, Tensor]: 81 | return next(self.data) 82 | 83 | def __len__(self) -> int: 84 | return -(self.w.shape[0] // -self.batch) 85 | 86 | @abstractmethod 87 | def encode(self, w: Tensor) -> Tensor: 88 | raise NotImplementedError 89 | 90 | @abstractmethod 91 | def decode(self, x: Tensor) -> Tensor: 92 | raise NotImplementedError 93 | 94 | 95 | class 
Time(ABC): 96 | 97 | @abstractmethod 98 | def sample(self, steps: int, size: int) -> Tensor: 99 | raise NotImplementedError 100 | 101 | 102 | @dataclass 103 | class Schedule(ABC): 104 | steps: int 105 | 106 | @abstractmethod 107 | def compute(self) -> Tensor: 108 | """Compute the diffusion schedule alpha_t for t = 0, ..., T""" 109 | raise NotImplementedError 110 | 111 | 112 | @dataclass 113 | class Noise(ABC, Generic[D]): 114 | 115 | @abstractmethod 116 | def schedule(self, alpha: Tensor) -> None: 117 | """Precompute needed resources based on the diffusion schedule""" 118 | raise NotImplementedError 119 | 120 | @abstractmethod 121 | def stationary(self, shape: tuple[int, ...]) -> D: 122 | """Compute the stationary distribution q(x_T)""" 123 | raise NotImplementedError 124 | 125 | @abstractmethod 126 | def prior(self, x: Tensor, t: Tensor) -> D: 127 | """Compute the prior distribution q(x_t | x_0)""" 128 | raise NotImplementedError 129 | 130 | @abstractmethod 131 | def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> D: 132 | """Compute the posterior distribution q(x_{t-1} | x_t, x_0)""" 133 | raise NotImplementedError 134 | 135 | @abstractmethod 136 | def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> D: 137 | """Compute the approximate posterior distribution p(x_{t-1} | x_t)""" 138 | raise NotImplementedError 139 | 140 | 141 | class Net(ABC, nn.Module): 142 | __call__: Callable[[Tensor, Tensor, Tensor], Tensor] 143 | 144 | @abstractmethod 145 | def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: 146 | raise NotImplementedError 147 | 148 | def __or__(self, module: nn.Module) -> "Net": 149 | return Sequential(self, module) # type: ignore 150 | 151 | 152 | class Guidance(ABC): 153 | pass 154 | 155 | 156 | class Loss(ABC, Generic[D]): 157 | 158 | @abstractmethod 159 | def compute(self, batch: Batch[D]) -> Tensor: 160 | raise NotImplementedError 161 | 162 | def __mul__(self, factor: float) -> "Mul[D]": 163 | return Mul(factor, self) 164 | 165 | def __rmul__(self, factor: float) -> "Mul[D]": 166 | return Mul(factor, self) 167 | 168 | def __truediv__(self, divisor: float) -> "Mul[D]": 169 | return Mul(1 / divisor, self) 170 | 171 | def __add__(self, other: "Loss[D]") -> "Add[D]": 172 | return Add(self, other) 173 | 174 | def __sub__(self, other: "Loss[D]") -> "Add[D]": 175 | return Add(self, Mul(-1, other)) 176 | 177 | 178 | @dataclass 179 | class Mul(Loss[D]): 180 | factor: float 181 | loss: Loss[D] 182 | 183 | def compute(self, batch: Batch[D]) -> Tensor: 184 | return self.factor * self.loss.compute(batch) 185 | 186 | 187 | class Add(Loss[D]): 188 | 189 | def __init__(self, *losses: Loss[D]) -> None: 190 | self.losses = losses 191 | 192 | def compute(self, batch: Batch[D]) -> Tensor: 193 | sum = self.losses[0].compute(batch) 194 | for loss in self.losses[1:]: 195 | sum += loss.compute(batch) 196 | return sum 197 | -------------------------------------------------------------------------------- /diffusion/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor, nn 6 | 7 | from .base import Data 8 | 9 | 10 | class Identity(Data): 11 | 12 | def encode(self, w: Tensor) -> Tensor: 13 | return w 14 | 15 | def decode(self, x: Tensor) -> Tensor: 16 | return x 17 | 18 | 19 | @dataclass 20 | class OneHot(Data): 21 | k: Optional[int] = None 22 | 23 | def __post_init__(self): 24 | assert self.k is not None 25 | self.i = 
torch.eye(self.k) 26 | 27 | def encode(self, w: Tensor) -> Tensor: 28 | self.i = self.i.to(w.device) 29 | return self.i[w] 30 | 31 | def decode(self, x: Tensor) -> Tensor: 32 | return x.argmax(-1) 33 | 34 | 35 | @dataclass 36 | class Embedding(Data): 37 | k: Optional[int] = None 38 | d: Optional[int] = None 39 | 40 | def __post_init__(self) -> None: 41 | assert self.k is not None and self.d is not None 42 | self.embedding = nn.Embedding(self.k, self.d) 43 | 44 | def encode(self, w: Tensor) -> Tensor: 45 | return self.embedding(w) 46 | 47 | def decode(self, x: Tensor) -> Tensor: 48 | return torch.cdist(x, self.embedding.weight).argmin(-1) -------------------------------------------------------------------------------- /diffusion/distribution.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor 5 | from typing_extensions import Self 6 | 7 | from .base import Distribution 8 | 9 | 10 | @dataclass 11 | class Normal(Distribution): 12 | mu: Tensor 13 | sigma: Tensor 14 | 15 | def sample(self) -> tuple[Tensor, Tensor]: 16 | epsilon = torch.randn(self.mu.shape, device=self.mu.device) 17 | return self.mu + self.sigma * epsilon, epsilon 18 | 19 | def nll(self, x: Tensor) -> Tensor: 20 | return (0.5 * ((x - self.mu) / self.sigma)**2 + 21 | (self.sigma * 2.5066282746310002).log()) 22 | 23 | def dkl(self, other: Self) -> Tensor: 24 | return (torch.log(other.sigma / self.sigma) + 25 | (self.sigma**2 + (self.mu - other.mu)**2) / (2 * other.sigma**2) - 0.5) 26 | 27 | 28 | @dataclass 29 | class Categorical(Distribution): 30 | p: Tensor 31 | 32 | def __post_init__(self) -> None: 33 | self.k = self.p.shape[-1] 34 | self.i = torch.eye(self.k, device=self.p.device) 35 | 36 | def sample(self) -> tuple[Tensor, None]: 37 | c = torch.multinomial(self.p.view(-1, self.k), 1, True) 38 | return self.i[c.view(*self.p.shape[:-1])], None 39 | 40 | def nll(self, x: Tensor) -> Tensor: 41 | return -((self.p * x).sum(-1) + 1e-6).log() 42 | 43 | def dkl(self, other: Self) -> Tensor: 44 | p1, p2 = self.p + 1e-6, other.p + 1e-6 45 | return (p1 * (p1.log() - p2.log())).sum(-1) 46 | -------------------------------------------------------------------------------- /diffusion/guidance.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .base import Guidance 4 | 5 | __all__ = ["ClassifierFree"] 6 | 7 | 8 | @dataclass 9 | class ClassifierFree(Guidance): 10 | dropout: float 11 | strength: float 12 | -------------------------------------------------------------------------------- /diffusion/loss.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Literal, TypeVar 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .base import Batch, Distribution, Loss 8 | 9 | __all__ = ["Lambda", "Simple", "VLB"] 10 | 11 | D = TypeVar("D", bound=Distribution) 12 | 13 | 14 | @dataclass 15 | class Lambda(Loss[D]): 16 | function: Callable[[Batch[D]], Tensor] 17 | 18 | def compute(self, batch: Batch[D]) -> Tensor: 19 | return self.function(batch) 20 | 21 | 22 | @dataclass 23 | class Simple(Loss[Distribution]): 24 | parameter: Literal["x", "epsilon"] = "x" 25 | index = 0 26 | 27 | def compute(self, batch: Batch[Distribution]) -> Tensor: 28 | return torch.mean((getattr(batch, self.parameter) - batch.hat[self.index])**2) 29 | 30 | 31 | class 
VLB(Loss[Distribution]): 32 | 33 | def compute(self, batch: Batch[Distribution]) -> Tensor: 34 | t = batch.t.view(-1, *(1,) * (batch.x.ndim - 1)) 35 | return batch.q.dkl(batch.p).where(t > 1, batch.p.nll(batch.x)).mean() 36 | -------------------------------------------------------------------------------- /diffusion/net.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import torch 4 | from einops.layers.torch import Rearrange 5 | from torch import Tensor 6 | from torch.nn import ( 7 | Conv2d, 8 | Embedding, 9 | GroupNorm, 10 | Identity, 11 | LayerNorm, 12 | Linear, 13 | Module, 14 | ModuleList, 15 | MultiheadAttention, 16 | Sequential, 17 | SiLU, 18 | Upsample, 19 | ) 20 | from torch.nn.functional import pad, silu 21 | from torchvision.transforms.functional import crop 22 | 23 | from .base import Net 24 | from .utils.nn import FastGELU, SinusoidalPositionalEmbedding, WeightStdConv2d, Swish 25 | 26 | __all__ = ["UNet", "Transformer"] 27 | 28 | 29 | class UNet(Net): 30 | 31 | class Block(Module): 32 | 33 | def __init__(self, hidden: int, channels: tuple[int, int], groups: int) -> None: 34 | super().__init__() 35 | self.linear = Linear(hidden, 2 * channels[1]) 36 | self.conv1 = WeightStdConv2d(*channels, 3, 1, 1) 37 | self.norm1 = GroupNorm(groups, channels[1]) 38 | self.conv2 = WeightStdConv2d(channels[1], channels[1], 3, 1, 1) 39 | self.norm2 = GroupNorm(groups, channels[1]) 40 | 41 | def forward(self, x: Tensor, c: Tensor) -> Tensor: 42 | a, b = torch.chunk(silu(self.linear(c))[..., None, None], 2, 1) 43 | x = silu(self.norm1(self.conv1(x)) * (a + 1) + b) 44 | x = silu(self.norm2(self.conv2(x))) 45 | return x 46 | 47 | class Attention(Module): 48 | 49 | def __init__(self, hidden: int, heads: int) -> None: 50 | super().__init__() 51 | self.attention = MultiheadAttention(hidden, heads) 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | shape = x.shape 55 | x = x.flatten(2).permute(2, 0, 1) 56 | x, _ = self.attention(x, x, x, need_weights=False) 57 | x = x.permute(1, 2, 0).reshape(shape) 58 | return x 59 | 60 | def __init__( 61 | self, 62 | channels: Sequence[int], 63 | labels: int = 0, 64 | parameters: int = 1, 65 | hidden: int = 256, 66 | heads: int = 8, 67 | groups: int = 16, 68 | ) -> None: 69 | super().__init__() 70 | self.label = Embedding(labels + 1, hidden) 71 | self.time = SinusoidalPositionalEmbedding(hidden) 72 | self.input = Conv2d(channels[0], channels[1], 3, 1, 1) 73 | self.encoder = ModuleList([ 74 | ModuleList([ 75 | UNet.Block(hidden, 2 * (channels_[0],), groups), 76 | UNet.Block(hidden, 2 * (channels_[0],), groups), 77 | UNet.Attention(channels_[0], heads), 78 | GroupNorm(groups, channels_[0]), 79 | Sequential( 80 | Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2) if 81 | (last := channels_[1] < channels[-1]) else Identity(), 82 | Conv2d(channels_[0] * (1 + 3 * last), channels_[1], 1), 83 | ) 84 | ]) for channels_ in zip(channels[1:], channels[2:]) 85 | ]) 86 | self.bottleneck = ModuleList([ 87 | UNet.Block(hidden, (channels[-1], channels[-1]), groups), 88 | UNet.Attention(channels[-1], heads), 89 | UNet.Block(hidden, (channels[-1], channels[-1]), groups), 90 | ]) 91 | self.decoder = ModuleList([ 92 | ModuleList([ 93 | UNet.Block(hidden, (sum(channels_), channels_[0]), groups), 94 | UNet.Block(hidden, (sum(channels_), channels_[0]), groups), 95 | UNet.Attention(channels_[0], heads), 96 | GroupNorm(groups, channels_[0]), 97 | Sequential( 98 | Upsample(None, 2, "nearest") 99 | if 
channels_[1] > channels[1] else Identity(), 100 | Conv2d(channels_[0], channels_[1], 3, 1, 1), 101 | ) 102 | ]) for channels_ in zip(channels[:1:-1], channels[-2::-1]) 103 | ]) 104 | self.output = Sequential( 105 | Conv2d(channels[1], parameters * channels[0], 3, 1, 1), 106 | Rearrange("b (p c) h w -> p b c h w", p=parameters), 107 | ) 108 | 109 | def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: 110 | x = self.input(x) 111 | c = self.label(y) + self.time(t) 112 | h = list[Tensor]() 113 | for block1, block2, attention, norm, transform in self.encoder: # type: ignore[assignment] 114 | x = block1(x, c) 115 | h.append(x) 116 | x = block2(x, c) 117 | x = attention(norm(x)) + x 118 | h.append(x) 119 | x = pad(x, (0, x.shape[2] % 2, 0, x.shape[3] % 2)) 120 | x = transform(x) 121 | x = self.bottleneck[0](x, c) 122 | x = self.bottleneck[1](x) 123 | x = self.bottleneck[2](x, c) 124 | for block1, block2, attention, norm, transform in self.decoder: # type: ignore[assignment] 125 | x = crop(x, 0, 0, *h[-1].shape[2:]) 126 | x = block1(torch.cat((x, h.pop()), 1), c) 127 | x = block2(torch.cat((x, h.pop()), 1), c) 128 | x = attention(norm(x)) + x 129 | x = transform(x) 130 | x = self.output(x) 131 | return x 132 | 133 | 134 | class Transformer(Net): 135 | 136 | class Block(Module): 137 | 138 | def __init__(self, width: int, heads: int) -> None: 139 | super().__init__() 140 | self.mlp1 = Sequential(SiLU(), Linear(width, 6 * width)) 141 | self.norm1 = LayerNorm(width) 142 | self.attn = MultiheadAttention(width, heads, batch_first=True) 143 | self.norm2 = LayerNorm(width) 144 | self.mlp2 = Sequential( 145 | Linear(width, width * 4), 146 | FastGELU(), 147 | Linear(width * 4, width), 148 | ) 149 | 150 | def forward(self, x: Tensor, c: Tensor) -> Tensor: 151 | a, b, c, d, e, f = torch.chunk(self.mlp1(c)[:, None], 6, 2) 152 | x = x + a * self.attn(*[b * self.norm1(x) + c] * 3, need_weights=False)[0] 153 | x = x + d * self.mlp2(e * self.norm2(x) + f) 154 | return x 155 | 156 | def __init__( 157 | self, 158 | input: int, 159 | labels: int = 0, 160 | parameters: int = 1, 161 | depth: int = 6, 162 | width: int = 256, 163 | heads: int = 8, 164 | ) -> None: 165 | super().__init__() 166 | self.linear1 = Linear(input, width) 167 | self.position = SinusoidalPositionalEmbedding(width) 168 | self.label = Embedding(labels + 1, width) 169 | self.time = SinusoidalPositionalEmbedding(width) 170 | self.blocks = Sequential( 171 | *[Transformer.Block(width, heads) for _ in range(depth)]) 172 | self.linear2 = Linear(width, input * parameters) 173 | self.rearrange = Rearrange("b l (p e) -> p b l e", p=parameters) 174 | 175 | def forward(self, x: Tensor, y: Tensor, t: Tensor) -> Tensor: 176 | x = self.linear1(x) + self.position(torch.arange(x.shape[1], device=x.device)) 177 | c = self.label(y) + self.time(t) 178 | for block in self.blocks: 179 | x = block(x, c) 180 | return self.rearrange(self.linear2(x)) 181 | -------------------------------------------------------------------------------- /diffusion/noise.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from dataclasses import dataclass 3 | from itertools import accumulate 4 | from typing import Literal 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from .base import Noise 10 | from .distribution import Categorical as Cat 11 | from .distribution import Normal as N 12 | 13 | __all__ = ["Gaussian", "Categorical"] 14 | 15 | 16 | @dataclass 17 | class Gaussian(Noise[N]): 18 | parameter: 
Literal["x", "epsilon", "mu"] = "x" 19 | variance: Literal["fixed", "range", "learned"] = "fixed" 20 | 21 | # TODO add lambda_ parameter to allow DDIM 22 | 23 | def schedule(self, alpha: Tensor) -> None: 24 | delta = alpha.cumprod(0) 25 | self.q1 = delta.sqrt() 26 | self.q2 = (1 - delta).sqrt() 27 | self.q3 = alpha.sqrt() * (1 - delta.roll(1)) / (1 - delta) 28 | self.q4 = delta.roll(1).sqrt() * (1 - alpha) / (1 - delta) 29 | self.q5 = ((1 - alpha) * (1 - delta.roll(1)) / (1 - delta)).sqrt() 30 | if self.parameter == "x": 31 | self.p1, self.p2 = self.q3, self.q4 32 | elif self.parameter == "epsilon": 33 | self.p1 = 1 / alpha.sqrt() 34 | self.p2 = (alpha - 1) / ((1 - delta).sqrt() * alpha.sqrt()) 35 | else: 36 | self.p1, self.p2 = torch.zeros(alpha.shape), torch.ones(alpha.shape) 37 | self.p3 = (1 - alpha).log() 38 | self.p4 = self.q5.log() 39 | 40 | def stationary(self, shape: tuple[int, ...]) -> N: 41 | return N(torch.zeros(shape), torch.ones(shape)) 42 | 43 | def prior(self, x: Tensor, t: Tensor) -> N: 44 | t = t.view(-1, *(1,) * (x.dim() - 1)) 45 | return N(self.q1[t] * x, self.q2[t]) 46 | 47 | def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> N: 48 | t = t.view(-1, *(1,) * (x.dim() - 1)) 49 | return N(self.q3[t] * z + self.q4[t] * x, self.q5[t]) 50 | 51 | def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> N: 52 | t = t.view(-1, *(1,) * (z.dim() - 1)) 53 | return N( 54 | self.p1[t] * z + self.p2[t] * hat[0], 55 | self.q5[t] 56 | if self.variance == "fixed" 57 | else torch.exp(hat[1] * self.p3[t] + (1 - hat[1]) * self.p4[t]) 58 | if self.variance == "range" 59 | else hat[1], 60 | ) 61 | 62 | 63 | class Categorical(Noise[Cat]): 64 | @abstractmethod 65 | def q(self, t: Tensor) -> Tensor: 66 | raise NotImplementedError 67 | 68 | @abstractmethod 69 | def r(self, t: Tensor) -> Tensor: 70 | raise NotImplementedError 71 | 72 | def prior(self, x: Tensor, t: Tensor) -> Cat: 73 | return Cat(x @ self.r(t)) 74 | 75 | def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> Cat: 76 | return Cat( 77 | (z @ self.q(t).transpose(1, 2)) 78 | * (x @ self.r(t - 1)) 79 | / (x @ self.r(t) * z).sum(2, keepdim=True) 80 | ) 81 | 82 | def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> Cat: 83 | return self.posterior(hat[0], z, t) 84 | 85 | 86 | class MemoryInefficientCategorical(Categorical): 87 | def schedule(self, alpha: Tensor) -> None: 88 | self._q = self.transition(alpha) 89 | self._r = torch.stack([*accumulate(self._q, torch.mm)]) 90 | 91 | @abstractmethod 92 | def transition(self, alpha: Tensor) -> Tensor: 93 | raise NotImplementedError 94 | 95 | def q(self, t: Tensor) -> Tensor: 96 | return self._q[t] 97 | 98 | def r(self, t: Tensor) -> Tensor: 99 | return self._r[t] 100 | 101 | 102 | @dataclass 103 | class MemoryEfficientCategorical(Categorical): 104 | k: int 105 | 106 | def schedule(self, alpha: Tensor) -> None: 107 | self.alpha = alpha.view(-1, 1, 1) 108 | self.delta = self.alpha.cumprod(0) 109 | self.i = torch.eye(self.k, device=alpha.device)[None] 110 | 111 | @property 112 | @abstractmethod 113 | def a(self) -> Tensor: 114 | raise NotImplementedError 115 | 116 | def q(self, t: Tensor) -> Tensor: 117 | return self.alpha[t] * self.i + (1 - self.alpha[t]) * self.a 118 | 119 | def r(self, t: Tensor) -> Tensor: 120 | return self.delta[t] * self.i + (1 - self.delta[t]) * self.a 121 | 122 | 123 | class Uniform(MemoryEfficientCategorical): 124 | @property 125 | def a(self) -> Tensor: 126 | return torch.ones(self.k, self.k, device=self.i.device) / self.k 127 | 128 | def stationary(self, 
shape: tuple[int, ...]) -> Cat: 129 | return Cat(torch.full(shape, 1 / self.k)) 130 | 131 | 132 | @dataclass 133 | class Absorbing(MemoryEfficientCategorical): 134 | m: int = -1 135 | 136 | @property 137 | def a(self) -> Tensor: 138 | return self.i[:, self.m].repeat(1, self.k, 1) 139 | 140 | def stationary(self, shape: tuple[int, ...]) -> Cat: 141 | return Cat(torch.eye(self.k)[self.m].repeat(*shape[:-1], 1)) 142 | -------------------------------------------------------------------------------- /diffusion/schedule.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from .base import Schedule 7 | 8 | __all__ = ["Constant", "Linear", "Cosine", "Sqrt"] 9 | 10 | 11 | @dataclass 12 | class Constant(Schedule): 13 | value: float 14 | 15 | def compute(self) -> Tensor: 16 | return torch.full((self.steps + 1,), self.value) 17 | 18 | 19 | @dataclass 20 | class Linear(Schedule): 21 | start: float 22 | end: float 23 | 24 | def compute(self) -> Tensor: 25 | return torch.linspace(self.start, self.end, self.steps + 1) 26 | 27 | 28 | @dataclass 29 | class Cosine(Schedule): 30 | offset: float = 8e-3 31 | exponent: float = 2 32 | 33 | def compute(self) -> Tensor: 34 | t = torch.arange(self.steps + 2) 35 | delta = ((t / (self.steps + 2) + self.offset) / (1 + self.offset) * torch.pi / 36 | 2).cos()**self.exponent 37 | alpha = torch.clip(delta[1:] / delta[:-1], 1e-3, 1) 38 | return alpha 39 | 40 | 41 | @dataclass 42 | class Sqrt(Schedule): 43 | offset: float = 8e-3 44 | 45 | def compute(self) -> Tensor: 46 | t = torch.arange(self.steps + 2) 47 | delta = 1 - torch.sqrt(t / (self.steps + 1) + self.offset) 48 | alpha = torch.clip(delta[1:] / delta[:-1], 0, 0.999) 49 | return alpha 50 | -------------------------------------------------------------------------------- /diffusion/time.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from .base import Time 7 | 8 | 9 | @dataclass 10 | class Discrete(Time): 11 | 12 | def sample(self, steps: int, size: int) -> Tensor: 13 | return torch.randint(1, steps + 1, (size,)) 14 | -------------------------------------------------------------------------------- /diffusion/utils/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Conv2d, Module, Parameter 6 | from torch.nn import Sequential as _Sequential 7 | from torch.nn.functional import conv2d 8 | 9 | __all__ = [ 10 | "FastGELU", 11 | "Lambda", 12 | "WeightStdConv2d", 13 | "Swish", 14 | "SinusoidalPositionalEmbedding", 15 | ] 16 | 17 | 18 | class Sequential(_Sequential): 19 | 20 | def forward(self, *input: Any) -> Any: 21 | for module in self._modules.values(): 22 | if type(input) == tuple: 23 | input = module(*input) 24 | else: 25 | input = module(input) 26 | return input 27 | 28 | 29 | class Lambda(Module): 30 | 31 | def __init__(self, function: Callable[..., Any]): 32 | super().__init__() 33 | self.function = function 34 | 35 | def forward(self, *input: Any): # type: ignore 36 | return self.function(*input) 37 | 38 | 39 | class WeightStdConv2d(Conv2d): 40 | 41 | def forward(self, input: Tensor) -> Tensor: 42 | mean = self.weight.mean((1, 2, 3), keepdim=True) 43 | var = self.weight.var((1, 2, 3), unbiased=False, 
keepdim=True) 44 | return conv2d( 45 | input, 46 | (self.weight - mean) * var.rsqrt(), 47 | self.bias, 48 | self.stride, 49 | self.padding, 50 | self.dilation, 51 | self.groups, 52 | ) 53 | 54 | 55 | class SinusoidalPositionalEmbedding(Module): 56 | 57 | def __init__(self, size: int = 32, base: float = 1e4) -> None: 58 | super().__init__() 59 | self.w = Parameter(base**torch.arange(0, -1, -2 / size)[None]) 60 | 61 | def forward(self, t: Tensor) -> Tensor: 62 | wt = self.w * t[:, None] 63 | return torch.stack((wt.sin(), wt.cos()), 2).flatten(1) 64 | 65 | 66 | class FastGELU(Module): 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | return x * torch.sigmoid(1.702 * x) 70 | 71 | 72 | class Swish(Module): 73 | 74 | def forward(self, x: Tensor) -> Tensor: 75 | return x * torch.sigmoid(x) -------------------------------------------------------------------------------- /diffusion/utils/nn/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def swish(x: Tensor) -> Tensor: 6 | return x * torch.sigmoid(x) -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | # build output 2 | dist/ 3 | 4 | # generated types 5 | .astro/ 6 | 7 | # dependencies 8 | node_modules/ 9 | 10 | # logs 11 | npm-debug.log* 12 | yarn-debug.log* 13 | yarn-error.log* 14 | pnpm-debug.log* 15 | 16 | # environment variables 17 | .env 18 | .env.production 19 | 20 | # macOS-specific files 21 | .DS_Store 22 | -------------------------------------------------------------------------------- /docs/astro.config.mjs: -------------------------------------------------------------------------------- 1 | import mdx from "@astrojs/mdx"; 2 | import sitemap from "@astrojs/sitemap"; 3 | import tailwind from "@astrojs/tailwind"; 4 | import { defineConfig } from "astro/config"; 5 | import rehypeKatex from "rehype-katex"; 6 | import remarkMath from "remark-math"; 7 | import remarkLayout from "./src/plugins/remark-layout.mjs"; 8 | 9 | // https://astro.build/config 10 | export default defineConfig({ 11 | integrations: [tailwind(), mdx(), sitemap()], 12 | markdown: { 13 | remarkPlugins: [remarkMath, remarkLayout], 14 | rehypePlugins: [rehypeKatex], 15 | }, 16 | site: "https://cabralpinto.github.io", 17 | base: "/modular-diffusion", 18 | redirects: { 19 | "/": "/modular-diffusion/guides/getting-started", 20 | }, 21 | }); 22 | -------------------------------------------------------------------------------- /docs/env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | /// 3 | -------------------------------------------------------------------------------- /docs/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "modular-diffusion-docs", 3 | "type": "module", 4 | "version": "0.0.1", 5 | "scripts": { 6 | "dev": "astro dev --host", 7 | "start": "astro dev", 8 | "build": "astro build", 9 | "preview": "astro preview", 10 | "astro": "astro" 11 | }, 12 | "dependencies": { 13 | "@astrojs/mdx": "^0.19.7", 14 | "@astrojs/sitemap": "^2.0.2", 15 | "@astrojs/tailwind": "^4.0.0", 16 | "@fontsource/roboto": "^5.0.8", 17 | "astro": "^2.10.1", 18 | "tailwindcss": "^3.3.3", 19 | "tailwindcss-opentype": "^1.1.0" 20 | }, 21 | "devDependencies": { 22 | "@tailwindcss/typography": "^0.5.9", 23 | "rehype-katex": "^6.0.3", 24 | "remark-math": 
"^5.1.1", 25 | "sass": "^1.65.1" 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /docs/public/fonts/FiraCode/FiraCode-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/FiraCode/FiraCode-Bold.ttf -------------------------------------------------------------------------------- /docs/public/fonts/FiraCode/FiraCode-Light.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/FiraCode/FiraCode-Light.ttf -------------------------------------------------------------------------------- /docs/public/fonts/FiraCode/FiraCode-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/FiraCode/FiraCode-Medium.ttf -------------------------------------------------------------------------------- /docs/public/fonts/FiraCode/FiraCode-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/FiraCode/FiraCode-Regular.ttf -------------------------------------------------------------------------------- /docs/public/fonts/FiraCode/FiraCode-SemiBold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/FiraCode/FiraCode-SemiBold.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-Black.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-Black.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-Bold.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-ExtraBold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-ExtraBold.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-ExtraLight.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-ExtraLight.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-Light.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-Light.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-Medium.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-Medium.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-Regular.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-SemiBold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-SemiBold.ttf -------------------------------------------------------------------------------- /docs/public/fonts/Inter/Inter-Thin.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/fonts/Inter/Inter-Thin.ttf -------------------------------------------------------------------------------- /docs/public/images/guides/getting-started/conditional.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/guides/getting-started/conditional.png -------------------------------------------------------------------------------- /docs/public/images/guides/getting-started/unconditional-cosine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/guides/getting-started/unconditional-cosine.png -------------------------------------------------------------------------------- /docs/public/images/guides/getting-started/unconditional-linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/guides/getting-started/unconditional-linear.png -------------------------------------------------------------------------------- /docs/public/images/logo.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/logo.ico -------------------------------------------------------------------------------- /docs/public/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/logo.png -------------------------------------------------------------------------------- /docs/public/images/logo2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/logo2.png -------------------------------------------------------------------------------- /docs/public/images/menu.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /docs/public/images/modules/noise-schedule/constant.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-schedule/constant.png -------------------------------------------------------------------------------- /docs/public/images/modules/noise-schedule/cosine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-schedule/cosine.png -------------------------------------------------------------------------------- /docs/public/images/modules/noise-schedule/linear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-schedule/linear.png -------------------------------------------------------------------------------- /docs/public/images/modules/noise-schedule/sqrt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-schedule/sqrt.png -------------------------------------------------------------------------------- /docs/public/images/modules/noise-type/absorbing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-type/absorbing.png -------------------------------------------------------------------------------- /docs/public/images/modules/noise-type/gaussian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-type/gaussian.png -------------------------------------------------------------------------------- /docs/public/images/modules/noise-type/uniform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/public/images/modules/noise-type/uniform.png -------------------------------------------------------------------------------- /docs/public/robots.txt: -------------------------------------------------------------------------------- 1 | User-agent: * 2 | Allow: / 3 | 4 | Sitemap: https://cabralpinto.github.io/sitemap-index.xml -------------------------------------------------------------------------------- /docs/src/components/Article.astro: -------------------------------------------------------------------------------- 1 |
4 | 5 |
6 | 7 | 72 | -------------------------------------------------------------------------------- /docs/src/components/Header.astro: -------------------------------------------------------------------------------- 1 |
4 |
5 | Menu 10 | 11 | Logo 16 |

Modular Diffusion

17 |
18 |
19 | GitHub 23 | Discord 27 | PyPI 31 |
32 |
33 | 34 | 49 | -------------------------------------------------------------------------------- /docs/src/components/Index.astro: -------------------------------------------------------------------------------- 1 | --- 2 | import type { MarkdownHeading } from "@astrojs/markdown-remark"; 3 | 4 | interface Props { 5 | headings: MarkdownHeading[]; 6 | } 7 | 8 | const headings = Astro.props.headings.filter((heading) => heading.depth < 3); 9 | headings[0].text = "Overview"; 10 | --- 11 | 12 | 28 | 29 | 56 | -------------------------------------------------------------------------------- /docs/src/components/Sidebar.astro: -------------------------------------------------------------------------------- 1 | --- 2 | const articles = (await Astro.glob("../pages/[!_]*/*.mdx")) 3 | .sort((a, b) => a.frontmatter.id - b.frontmatter.id) 4 | .reduce<{ [key: string]: { title: string; href: string }[] }>( 5 | (articles, article) => { 6 | let group = article.file.split("/").slice(-2, -1)[0]; 7 | group = `${group[0].toUpperCase()}${group.slice(1)}`; 8 | articles[group] = [ 9 | ...(articles[group] || []), 10 | { title: article.frontmatter.title, href: article.url ?? "" }, 11 | ]; 12 | return articles; 13 | }, 14 | {} 15 | ); 16 | --- 17 | 18 | 38 | -------------------------------------------------------------------------------- /docs/src/components/markdown/Bullet.astro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Bullet.astro -------------------------------------------------------------------------------- /docs/src/components/markdown/Code.astro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Code.astro -------------------------------------------------------------------------------- /docs/src/components/markdown/Image.astro: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /docs/src/components/markdown/Link.astro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Link.astro -------------------------------------------------------------------------------- /docs/src/components/markdown/Note.astro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Note.astro -------------------------------------------------------------------------------- /docs/src/components/markdown/Paragraph.astro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Paragraph.astro -------------------------------------------------------------------------------- /docs/src/components/markdown/Subtitle.astro: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Subtitle.astro -------------------------------------------------------------------------------- /docs/src/components/markdown/Title.astro: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/docs/src/components/markdown/Title.astro -------------------------------------------------------------------------------- /docs/src/env.d.ts: -------------------------------------------------------------------------------- 1 | /// 2 | /// -------------------------------------------------------------------------------- /docs/src/layouts/Layout.astro: -------------------------------------------------------------------------------- 1 | --- 2 | import type { MarkdownLayoutProps } from "astro"; 3 | import Article from "../components/Article.astro"; 4 | import Header from "../components/Header.astro"; 5 | import Index from "../components/Index.astro"; 6 | import Sidebar from "../components/Sidebar.astro"; 7 | import "../styles/fonts.css"; 8 | 9 | type Props = MarkdownLayoutProps<{ 10 | title: string; 11 | index: boolean; 12 | }>; 13 | 14 | const { 15 | headings, 16 | frontmatter: { title, index }, 17 | } = Astro.props; 18 | --- 19 | 20 | 21 | 22 | 23 | 24 | 25 | 30 | {title} | Modular Diffusion 31 | 35 | 36 | 42 | 43 | 44 |
45 |
46 | 47 |
48 |
49 | 50 | { 51 | index && ( 52 | <> 53 |
54 |

55 | 56 | If you spot any typo or technical imprecision, please submit 57 | an issue or pull request to the library's 58 | 59 | GitHub repository 60 | 61 | . 62 | 63 |

64 | 65 | ) 66 | } 67 |
68 | {index ? : ""} 69 |
70 |
71 | 72 | 73 | -------------------------------------------------------------------------------- /docs/src/pages/_drafts/model.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.1 3 | title: "diffusion.Model" 4 | index: true 5 | --- 6 | 7 | # diffusion.Model 8 | 9 | ## load() 10 | 11 | ### Parameters 12 | 13 | - `path: str`: Path to model. 14 | 15 | ## save() 16 | 17 | ### Parameters 18 | 19 | - `path: str`: Path to model. 20 | 21 | ## train() 22 | 23 | ### Parameters 24 | 25 | - `epochs: int = 1`: Number of epochs. 26 | - `progress: bool = True`: Show progress bar. 27 | 28 | ### Returns 29 | 30 | - `Iterator[float]`: Training losses. 31 | 32 | ## sample() 33 | 34 | ### Parameters 35 | 36 | - `y: Optional[Tensor] = None`: Labels for conditional sampling. 37 | - `batch: int = 1`: Batch size. 38 | - `progress: bool = True`: Show progress bar. -------------------------------------------------------------------------------- /docs/src/pages/guides/custom-modules.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 1.2 3 | title: "Custom Modules" 4 | index: true 5 | --- 6 | 7 | 8 | 9 | # {frontmatter.title} 10 | 11 | When tinkering with Diffusion Models, the time will come when you need to venture beyond what the base library offers and modify the diffusion process to fit your needs. Modular Diffusion meets this requirement by providing an abstract base class for each module type, which can be extended to define custom behavior. In this tutorial, we provide an overview of each base class and an example of how to extend it. 12 | 13 | > Type annotations 14 | > 15 | > As with all library code, this tutorial adheres to strict type checking standards. Although we recommend typing your code, you may elect to avoid writing type annotations. By skipping this step, however, you will not receive a warning if you try to mix incompatible modules, or other useful intellisense. 16 | 17 | ## Data transform 18 | 19 | In many Diffusion Model applications, the diffusion process takes place in the dataset space. If this is your case, the prebuilt `Identity` data transform module will serve your purposes, leaving your data untouched before applying noise during training. However, a growing number of algorithms, like [Stable Diffusion](https://arxiv.org/abs/2112.10752) and [Diffusion-LM](https://arxiv.org/abs/2205.14217), project data onto a latent space before applying diffusion. 20 | 21 | In the case of Diffusion-LM, the dataset consists of sequences of word IDs, but the diffusion process happens in the word embedding space. This means you need a way of converting sequences of word IDs into sequences of embeddings, and train the embeddings along with the Diffusion Model. In Modular Diffusion, this can be achieved by extending the `Data` base class and implement its `encode` and `decode` methods. The former projects the data into the latent space and the latter retrieves it to the dataset space. 
Let's take a look at how you could implement the aforementioned transform: 22 | 23 | ```python 24 | from diffusion.base import Data 25 | 26 | @dataclass 27 | class Embedding(Data): 28 | count: int = 2 29 | dimension: int = 256 30 | 31 | def __post_init__(self) -> None: 32 | self.embedding = nn.Embedding(self.count, self.dimension) 33 | 34 | def encode(self, w: Tensor) -> Tensor: 35 | return self.embedding(w) 36 | 37 | def decode(self, x: Tensor) -> Tensor: 38 | return torch.cdist(x, self.embedding.weight).argmin(-1) 39 | ``` 40 | 41 | In the `encode` method, we are transforming the input tensor `w` into an embedding tensor using the learned embedding layer. The `decode` method reverses this operation by finding the most similar embedding in the embedding weight matrix to each vector in `x`. 42 | 43 | Data transforms can also be useful in cases where they have no trainable parameters. For example, the `Categorical` noise module operates over one-hot vectors, which are very memory-inefficient. To mitigate this, you may store your data as a list of labels and use the `OneHot` data transform module to transform it into one-hot vectors on a batch-by-batch basis, saving you a lot of memory. Or your data transform can just be a frozen variational autoencoder, like in [Stable Diffusion](https://arxiv.org/abs/2112.10752). For further details, check out our [Text Generation](/modular-diffusion/guides/text-generation) and [Image Generation](/modular-diffusion/guides/image-generation) tutorials. 44 | 45 | ## Noise schedule 46 | 47 | You can implement your own custom diffusion schedule by extending the abstract `Schedule` base class and implementing its only abstract method, `compute`. This method is responsible for providing a tensor containing the values for $\alpha_t$ for $t \in \{0,\dots,T\}$. As an example, let's implement the `Linear` schedule, which is already included in the library: 48 | 49 | ```python 50 | from dataclasses import dataclass 51 | from diffusion.base import Schedule 52 | 53 | @dataclass 54 | class Linear(Schedule): 55 | start: float 56 | end: float 57 | 58 | def compute(self) -> Tensor: 59 | return torch.linspace(self.start, self.end, self.steps + 1) 60 | ``` 61 | 62 | Given that `steps` is already a parameter in the base class, all we need to do is define `start` and `end` parameters, and use them to compute the $\alpha_t$ values. Then, we can initialize the schedule with the syntax `Linear(steps, start, end)`. 63 | ## Probability distribution 64 | 65 | In the diffusion process, the chosen probability distribution plays a crucial role in modeling the noise that guides the transition between different states. The library comes prepackaged with a growing set of commonly used distributions, such as the `Normal` distribution, but different applications or experimental setups might require you to implement your own. 66 | 67 | To define a custom distribution, you'll need to extend the `Distribution` base class and implement three key methods: `sample`, which draws a sample from the distribution and returns a tuple containing the sampled value and the applied noise (or `None` if not applicable); `nll`, which computes the negative log-likelihood of the given tensor `x`; and `dkl`, which computes the Kullback-Leibler Divergence between the distribution and another provided as `other`.
Take, for example, the `Normal` distribution, included in the library: 68 | 69 | ```python 70 | @dataclass 71 | 72 | class Normal(Distribution): 73 | mu: Tensor 74 | sigma: Tensor 75 | 76 | def sample(self) -> tuple[Tensor, Tensor]: 77 | epsilon = torch.randn(self.mu.shape, device=self.mu.device) 78 | return self.mu + self.sigma * epsilon, epsilon 79 | 80 | def nll(self, x: Tensor) -> Tensor: 81 | return (0.5 * ((x - self.mu) / self.sigma)**2 + \ 82 | (self.sigma * 2.5066282746310002).log()) 83 | 84 | def dkl(self, other: Self) -> Tensor: 85 | return (torch.log(other.sigma / self.sigma) + \ 86 | (self.sigma**2 + (self.mu - other.mu)**2) / (2 * other.sigma**2) - 0.5) 87 | ``` 88 | 89 | > Parameter shapes 90 | > 91 | > The distribution parameters are represented as tensors with the same size as a batch. This essentially means that a `Distribution` object functions as a collection of distributions, where each individual element in a batch corresponds to a unique distribution. For instance, each pixel in a batch of images is associated with its own `mu` and `sigma` values. 92 | 93 | ## Noise type 94 | 95 | In most Diffusion Model applications, the standard choice of noise is Gaussian, which is already bundled within the library. However, there may be scenarios where you want to experiment with variations of the standard Gaussian noise, as in DDIM introduced in [Song et al. 2020](https://arxiv.org/abs/2010.02502), or venture into entirely different noise types, like the one used in D3PM, introduced in [Austin et al. (2021)](https://arxiv.org/abs/2107.03006). To create your own unique noise behavior, you will need to extend the abstract `Noise` base class, and implement each one of the following methods: 96 | 97 | - `schedule(self, alpha: Tensor) -> None`: This method is intended for precomputing resources based on the noise schedule $\alpha_t$ for $t \in {0,\dots,T}$. This can be beneficial for performance reasons when some calculations can be done ahead of time. A common use is calculating $\bar{\alpha}_{t}=\prod_{t=1}^{T}\alpha_{t}$. 98 | - `stationary(self, shape: tuple[int, ...]) -> Distribution`: This method is tasked with computing the stationary distribution $q(x_T)$, i.e., the noise distribution at the final time step, given a target shape. 99 | - `prior(self, x: Tensor, t: Tensor) -> Distribution`: This method computes the prior distribution $q(x_t | x_0)$, i.e., the distribution of the noisy images $x_t$ or `z` given the initial image $x_0$ or `x`. 100 | - `posterior(self, x: Tensor, z: Tensor, t: Tensor) -> Distribution`: This method computes the posterior distribution $q(x_{t-1} | x_t, x_0)$, i.e., the distribution of the less noisy images $x_{t-1}$ given the current noisy image $x_t$ or `z` and the initial image $x_0$ or `x`. 101 | - `approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> Distribution`: This method computes the approximate posterior distribution $p_\theta(x_{t-1} | x_t)$, i.e., the distribution of the less noisy images $x_{t-1}$ given the current noisy image $x_t$ or `z`. This is an approximation to the true posterior distribution that is easier to sample from or compute. The tensor `hat` is the output of the denoiser network containing the predicted parameters -- named this way because predicted values often are denoted with a hat, e.g., $\hat{\epsilon}$. 102 | 103 | If you aim to replicate a specific research paper, only need to translate the mathematical expressions into code. 
For example, the original DDPM paper yields the following equations: 104 | 105 | - $q(x_{T})=\mathcal{N}(x_T; 0, \text{I})$ 106 | - $q(x_{t}|x_{0})=\mathcal{N}(x_{t};\sqrt{\bar{\alpha}_{t}}x_{0},(1 - \bar{\alpha}_{t})\text{I})$ 107 | - $q(x_{t-1}|x_{t},x_{0})=\mathcal{N}(x_{t-1};\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})x_{t} + \sqrt{\bar\alpha_{t-1}}(1-\alpha_t)x_0}{1 -\bar\alpha_{t}},\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 -\bar\alpha_{t}}\text{I})$ 108 | - $p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1};\frac{1}{\sqrt{\alpha_t}}x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}}\hat{\epsilon}_\theta,\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 -\bar\alpha_{t}}\text{I})$ 109 | 110 | where $\bar{\alpha}_{t}=\prod_{s=1}^{t}\alpha_{s}$ is calculated beforehand for better performance. In Modular Diffusion, here's how we could implement this type of Gaussian noise: 111 | 112 | ```python 113 | from diffusion.base import Noise 114 | from diffusion.distribution import Normal as N 115 | 116 | @dataclass 117 | class Gaussian(Noise[N]): 118 | def schedule(self, alpha: Tensor) -> None: 119 | self.alpha = alpha 120 | self.delta = alpha.cumprod(0) 121 | 122 | def stationary(self, shape: tuple[int, ...]) -> N: 123 | return N(torch.zeros(shape), torch.ones(shape)) 124 | 125 | def prior(self, x: Tensor, t: Tensor) -> N: 126 | t = t.view(-1, *(1,) * (x.dim() - 1)) 127 | return N(self.delta[t].sqrt() * x, (1 - self.delta[t]).sqrt()) 128 | 129 | def posterior(self, x: Tensor, z: Tensor, t: Tensor) -> N: 130 | t = t.view(-1, *(1,) * (x.dim() - 1)) 131 | mu = self.alpha[t].sqrt() * (1 - self.delta[t - 1]) * z 132 | mu += self.delta[t - 1].sqrt() * (1 - self.alpha[t]) * x 133 | mu /= (1 - self.delta[t]) 134 | sigma = (1 - self.alpha[t]) * (1 - self.delta[t - 1]) / (1 - self.delta[t]) 135 | sigma = sigma.sqrt() 136 | return N(mu, sigma) 137 | 138 | def approximate(self, z: Tensor, t: Tensor, hat: Tensor) -> N: 139 | t = t.view(-1, *(1,) * (z.dim() - 1)) 140 | mu = (z - (1 - self.alpha[t]) / (1 - self.delta[t]).sqrt() * hat[0]) 141 | mu /= self.alpha[t].sqrt() 142 | sigma = (1 - self.alpha[t]) * (1 - self.delta[t - 1]) / (1 - self.delta[t]) 143 | sigma = sigma.sqrt() 144 | return N(mu, sigma) 145 | ``` 146 | 147 | > Broadcasting 148 | > 149 | > You will notice that some methods start with a statement that reshapes the tensor `t`. This is only done to allow broadcasting of the tensors in the subsequent operations. For instance, in the `prior` method, we need to multiply `self.delta[t].sqrt()` by `x`, but `self.delta` has shape `[t]` and `x` has shape `[b, c, h, w]`. By reshaping `t` to `[b, 1, 1, 1]`, we can multiply `self.delta[t].sqrt()` by `x` without any issues. 150 | 151 | The `schedule` method precomputes `alpha` and `delta` (the cumulative product of `alpha`) values, which are used in the other methods. The `stationary` method defines the initial noise distribution, while the `prior`, `posterior`, and `approximate` methods implement the corresponding mathematical equations for the prior, posterior, and approximate posterior distributions. Collectively, these methods define the complete Gaussian noise model from the original DDPM paper. Note that it is possible to achieve a more efficient solution by precomputing some of the recurrent expressions used in the methods. 152 | 153 | ## Denoiser neural network 154 | 155 | Modular Diffusion comes with general-use `UNet` and `Transformer` classes, which have proven to be effective denoising networks in the context of Diffusion Models.
However, it is not uncommon to see authors make modifications to these networks to achieve even better results. To design your own original network, extend the abstract `Net` base class. This class acts only as a thin wrapper over the standard Pytorch `nn.Module` class, meaning you can use it exactly the same way. The `forward` method should take three tensor arguments: the noisy input `x`, the conditioning matrix `y`, and the diffusion time steps `t`. 156 | 157 | > Network output shape 158 | > 159 | > When creating your neural network, it's important to remember that the first dimension of its output will be interpreted as the parameter index, irrespective of the number of parameters being predicted. For instance, if your network is predicting both the mean and variance of noise in an image, the output shape should be `[2, c, h, w]`. But even if you're predicting only the mean, the shape should be `[1, c, h, w]` -- not `[c, h, w]`. 160 | 161 | In scenarios where your network requires only a post-processing step, such as applying a `Softmax` function, there's no need to create an entirely new network class. Modular Diffusion allows for a more concise approach using the pipe operator, as shown in the [Getting Started](/modular-diffusion/guides/getting-started) tutorial: 162 | 163 | ```python 164 | from diffusion.net import Transformer 165 | from torch.nn import Softmax 166 | 167 | net = Transformer(input=512) | Softmax(3) 168 | ``` 169 | 170 | ## Loss function 171 | 172 | In each training step, your `Model` instance creates a `Batch` object, which contains all the information you need about the current batch to compute the corresponding loss. To create a custom loss function, you can extend the `Loss` base class and implement the `compute` method, where the loss is calculated based on the current batch. Let's start by implementing $L_\text{simple}$, introduced in [Ho et al. (2020)](https://arxiv.org/abs/2006.11239). The formula for this loss function is $\mathbb{E}\left[ \lvert\lvert \epsilon - \hat{\epsilon}_\theta \rvert\rvert ^2 \right]$, where $\epsilon$ is the noise added and $\hat{\epsilon}_\theta$ is the predicted noise. 173 | 174 | ```python 175 | from diffusion.base import Batch, Distribution, Loss 176 | from torch import Tensor 177 | class Simple(Loss[Distribution]): 178 | def compute(self, batch: Batch[Distribution]) -> Tensor: 179 | return ((batch.epsilon - batch.hat[0])**2).mean() 180 | ``` 181 | 182 | Notice how we parametrize the `Loss` and `Batch` classes with the `Distribution` type. This just tells your IDE you can use this loss class for any kind of distribution. If you'd like to make a loss function that is only compatible with, say, `Normal` distributions, you should specify this inside the square brackets. Another thing to note is how we assume that the first parameter in the denoiser neural network output `hat` (named this way because predictions are often denoted with a little hat) is $\hat{\epsilon}_\theta$. You can alter this behavior by changing the index or even making it parametrizable with a class property. 183 | 184 | In certain scenarios, you might not want to compute your loss from `batch.hat` directly, but rather from the approximate posterior distribution $p_\theta(x_{t-1} | x_t)$, which is itself estimated from `batch.hat` in the `Noise` module. This is the case when you need to compute the variational lower bound (VLB), the original loss function utilized to train Diffusion Models.
The formula for the VLB is expressed as: 185 | 186 | $$\begin{aligned}L_\text{vlb} & = \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_{\theta}(x_0|x_1)\right] \\ & - \sum_{t=2}^{T} \mathbb{E}_{q(x_{t}|x_0)}\left[D_{KL}(q(x_{t-1}|x_t, x_0)||p_{\theta}(x_{t-1}|x_t))\right] \\ & - D_{KL}(q(x_T|x_0)||p(x_T))\end{aligned}$$ 187 | 188 | Considering that the $D_{KL}(q(x_T|x_0)||p(x_T))$ term is assumed to be 0 in the context of Diffusion Models, you can implement this function as follows: 189 | 190 | ```python 191 | class VLB(Loss[Distribution]): 192 | def compute(self, batch: Batch[Distribution]) -> Tensor: 193 | t = batch.t.view(-1, *(1,) * (batch.x.ndim - 1)) 194 | return batch.q.dkl(batch.p).where(t > 1, batch.p.nll(batch.x)).mean() 195 | ``` 196 | 197 | Here, `batch.p` and `batch.q` represent $p_\theta(x_{t-1} | x_t)$ and $q(x_{t-1} | x_t, x_0)$, respectively. For a full list of `Batch` properties, check out the library's [API Reference](/modular-diffusion/modules/loss-function#training-batch). 198 | 199 | On the other hand, if you wish to train your model using a hybrid loss function that is a linear combination of two or more existing functions, you can do so without creating a new `Hybrid` module. For instance, to combine the `Simple` and `VLB` loss functions, as proposed in [Nichol & Dhariwal (2021)](https://arxiv.org/abs/2102.09672), you can use the following syntax. 200 | 201 | ```python 202 | from diffusion.loss import Simple, VLB 203 | 204 | loss = Simple(parameter="epsilon") + 0.001 * VLB() 205 | ``` 206 | 207 | ## Guidance 208 | 209 | As of right now, `ClassifierFree` guidance is hardcoded into the diffusion process, and there is no way of extending the base `Guidance` class, unless you create your own custom `Model` class. You can expect this behavior to change in an upcoming release. Please refer to our official [Issue Tracker](https://github.com/cabralpinto/modular-diffusion/issues) for updates. 210 | -------------------------------------------------------------------------------- /docs/src/pages/guides/getting-started.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 1.1 3 | title: "Getting Started" 4 | index: true 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | Welcome to Modular Diffusion! This tutorial highlights the core features of the package and will put you on your way to prototype and train your own Diffusion Models. For more advanced use cases and further details, check out our other tutorials and the library's API reference. 10 | 11 | > Prerequisites 12 | > 13 | > This tutorial assumes basic familiarity with Diffusion Models. If you are just hearing about Diffusion Models, you can find out more in one of the [many tutorials out there](https://diff-usion.github.io/Awesome-Diffusion-Models/#introductory-posts). 14 | 15 | ## Install the package 16 | 17 | Before you start, please install Modular Diffusion in your local Python environment by running the following command: 18 | 19 | ```sh 20 | python -m pip install modular-diffusion 21 | ``` 22 | 23 | Additionally, ensure you've installed the correct [Pytorch distribution](https://pytorch.org/get-started/locally/) for your system. 24 | 25 | ## Train a simple model 26 | 27 | The first step before training a Diffusion Model is to load your dataset. In this example, we will be using [MNIST](http://yann.lecun.com/exdb/mnist/), which includes 70,000 grayscale images of handwritten digits, and is a great simple dataset to prototype your image models. 
We are going to load MNIST with [Pytorch Vision](https://pytorch.org/vision/stable/index.html), but you can load your dataset any way you like, as long as it results in a `torch.Tensor` object. We are also going to discard the labels and scale the data to the commonly used $[-1, 1]$ range. 28 | 29 | ```python 30 | import torch 31 | from torchvision.datasets import MNIST 32 | from torchvision.transforms import ToTensor 33 | 34 | x, _ = zip(*MNIST(str(input), transform=ToTensor(), download=True)) 35 | x = torch.stack(x) * 2 - 1 36 | ``` 37 | 38 | Let's build our Diffusion Model next. Modular Diffusion provides you with the `diffusion.Model` class, which takes as parameters a **data transform**, a **noise schedule**, a **noise type**, a **denoiser neural network**, and a **loss function**, along with other optional parameters. You can import prebuilt components for these parameters from the different modules inside Modular Diffusion or build your own. Let's take a look at a simple example which replicates the architecture introduced in [Ho et al. (2020)](https://arxiv.org/abs/2006.11239), using only prebuilt components: 39 | 40 | ```python 41 | import diffusion 42 | from diffusion.data import Identity 43 | from diffusion.loss import Simple 44 | from diffusion.net import UNet 45 | from diffusion.noise import Gaussian 46 | from diffusion.schedule import Linear 47 | 48 | model = diffusion.Model( 49 | data=Identity(x, batch=128, shuffle=True), 50 | schedule=Linear(1000, 0.9999, 0.98), 51 | noise=Gaussian(parameter="epsilon", variance="fixed"), 52 | net=UNet(channels=(1, 64, 128, 256)), 53 | loss=Simple(parameter="epsilon"), 54 | device="cuda" if torch.cuda.is_available() else "cpu", 55 | ) 56 | ``` 57 | 58 | You might have noticed that we also added a `device` parameter to the model, which is important if you're looking to train on the GPU. We are now all set to train and sample from the model. We will train the model for 20 epochs and sample 10 images from it. 59 | 60 | ```python 61 | losses = [*model.train(epochs=20)] 62 | z = model.sample(batch=10) 63 | ``` 64 | 65 | > Tip 66 | > 67 | > If you are getting a `Process killed` message when training your model, try reducing the batch size in the data module. This error is caused by running out of RAM. 68 | 69 | The `sample` function returns a tensor with the same shape as the dataset tensor, but with an extra diffusion time dimension. In this case, the dataset has shape `[b, c, h, w]`, so our output `z` has shape `[t, b, c, h, w]`. Now we just need to rearrange the dimensions of the output tensor to produce one final image. 70 | 71 | ```python 72 | from einops import rearrange 73 | from torchvision.utils import save_image 74 | 75 | z = z[torch.linspace(0, z.shape[0] - 1, 10).int()] 76 | z = rearrange(z, "t b c h w -> c (b h) (t w)") 77 | save_image((z + 1) / 2, "output.png") 78 | ``` 79 | 80 | And that's it! The image we just saved should look something like this: 81 | 82 | ![Random numbers being generated from noise.](/modular-diffusion/images/guides/getting-started/unconditional-linear.png) 83 | 84 | ### Add a validation loop 85 | 86 | You might have noticed that the `train` method returns a generator object. This is to allow you to validate the model between epochs inside a `for` loop. For instance, you can see how your model is coming along by sampling from it between each training epoch, rather than only at the end. 
87 | 88 | ```python 89 | for epoch, loss in enumerate(model.train(epochs=20)): 90 | z = model.sample(batch=10) 91 | z = z[torch.linspace(0, z.shape[0] - 1, 10).int()] 92 | z = rearrange(z, "t b c h w -> c (b h) (t w)") 93 | save_image((z + 1) / 2, f"{epoch}.png") 94 | ``` 95 | 96 | > Tip 97 | > 98 | > If you're only interested in seeing the final results, sample the model with the following syntax: `*_, z = model.sample(batch=10)`. In this example, this will yield a tensor with shape `[b, c, h, w]` containing only the generated images. 99 | 100 | ### Swap modules 101 | 102 | The beauty in Modular Diffusion is how easy it is to make changes to an existing model. To showcase this, let's plug in the `Cosine` schedule introduced in [Nichol & Dhariwal (2021)](https://arxiv.org/abs/2102.09672). All it does is destroy information at a slower rate in the forward diffusion process, which was shown to improve sample quality. 103 | 104 | ```python 105 | from diffusion.schedule import Cosine 106 | 107 | model = diffusion.Model( 108 | data=Identity(x, batch=128, shuffle=True), 109 | schedule=Cosine(steps=1000), # changed the schedule! 110 | noise=Gaussian(parameter="epsilon", variance="fixed"), 111 | net=UNet(channels=(1, 64, 128, 256)), 112 | loss=Simple(parameter="epsilon"), 113 | device="cuda" if torch.cuda.is_available() else "cpu", 114 | ) 115 | ``` 116 | 117 | By keeping the rest of the code the same, we end up with the following result: 118 | 119 | ![Random numbers being generated from noise at a slower rate.](/modular-diffusion/images/guides/getting-started/unconditional-cosine.png) 120 | 121 | You can see that, because we used the cosine schedule, the denoising process is more gradual compared to the previous example. 122 | 123 | ## Train a conditional model 124 | 125 | In most Diffusion Model applications, you'll want to be able to condition the generation process. To show you how you can do this in Modular Diffusion, we'll continue working with the MNIST dataset, but this time we want to be able to control what digits we generate. Like before, we're going to load and preprocess the dataset, but this time we want to keep the labels, which tell us what number is in each image. We are also going to move the labels one unit up, since the label 0 is reserved for the null class. 126 | 127 | ```python 128 | x, y = zip(*MNIST(str(input), transform=ToTensor(), download=True)) 129 | x, y = torch.stack(x) * 2 - 1, torch.tensor(y) + 1 130 | ``` 131 | 132 | Once again, let's assemble our Diffusion Model. This time, we will add the labels `y` in our data transform object and provide the number of labels to our denoiser network. Let's also add classifier-free guidance to the model, a technique introduced in [Ho et al. (2022)](https://arxiv.org/abs/2207.12598) to improve sample quality in conditional generation, at the cost of extra sample time and less sample variety. 133 | 134 | ```python 135 | from diffusion.guidance import ClassifierFree 136 | 137 | model = diffusion.Model( 138 | data=Identity(x, y, batch=128, shuffle=True), # added y in here! 139 | schedule=Cosine(steps=1000), 140 | noise=Gaussian(parameter="epsilon", variance="fixed"), 141 | net=UNet(channels=(1, 64, 128, 256), labels=10), # added labels here! 142 | guidance=ClassifierFree(dropout=0.1, strength=2), # added classifier guidance! 
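    # dropout: fraction of the batch labels randomly replaced with the null class during training
    # strength: scale of the guidance signal applied when sampling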
143 | loss=Simple(parameter="epsilon"), 144 | device="cuda" if torch.cuda.is_available() else "cpu", 145 | ) 146 | ``` 147 | 148 | One final change we will be making compared to our previous example is to provide the labels of the images we wish to generate to the `sample` function. As an example, let's request one image of each digit by replacing `model.sample(batch=10)` with `model.sample(y=torch.arange(1, 11))`. We then end up with the following image: 149 | 150 | ![Numbers 0 through 9 being generated from noise.](/modular-diffusion/images/guides/getting-started/conditional.png) 151 | 152 | Pretty cool, huh? You can see how we can now choose which digit we sample from the model. This is, of course, only the tip of the iceberg. If you are looking for more advanced conditioning techniques, such as the one used in [DALL·E 2](https://openai.com/dall-e-2), please refer to our [Image Generation Guide](/modular-diffusion/guides/image-generation). 153 | 154 | ## Save and load the model 155 | 156 | Once you're done training your Diffusion Model, you may wish to store it for later. Modular Diffusion provides you with an intuitive interface to achieve this. Below is the syntax for saving the model: 157 | 158 | ```python 159 | model.save("model.pt") 160 | ``` 161 | 162 | In order to load it back, use the following snippet: 163 | 164 | ```python 165 | from pathlib import Path 166 | 167 | if Path("model.pt").exists(): 168 | model.load("model.pt") 169 | ``` 170 | 171 | Remember to always initialize the model prior to loading it, preferably with the same parameters you trained the model with. The `load` function will then populate the model weights with the ones you have saved. 172 | 173 | > Warning 174 | > 175 | > In some scenarios, you might want to introduce changes to the model architecture before you load it in. In these cases, it is important to keep in mind that structures that hold trainable weights, like the `net` parameter, cannot be changed, or your script will crash. Moreover, your Diffusion Model will most likely need to be trained for a few additional epochs if you make any changes to its parameters. 176 | 177 | ## Create your own modules 178 | 179 | As you've seen, Modular Diffusion provides you with a library of prebuilt modules you can plug into and out of your model according to your needs. Sometimes, however, you may need to customize the model behavior beyond what the library already offers. To address this, each module type has an abstract base class, which serves as a blueprint for new modules. To create your own custom module, simply inherit from the base class and implement the required methods. 180 | 181 | Suppose, for example, you want to implement your own custom noise schedule. You can achieve this by extending the abstract `Schedule` base class and implementing its only abstract method, `compute`. This method is responsible for providing a tensor containing the values for $\alpha_t$ for $t \in \{0,\dots,T\}$. As an example, let's reimplement the `Linear` schedule: 182 | 183 | ```python 184 | from dataclasses import dataclass 185 | from diffusion.base import Schedule 186 | 187 | @dataclass 188 | class Linear(Schedule): 189 | start: float 190 | end: float 191 | 192 | def compute(self) -> Tensor: 193 | return torch.linspace(self.start, self.end, self.steps + 1) 194 | ``` 195 | 196 | Given that `steps` is already a parameter in the base class, all we need to do is define `start` and `end` parameters, and use them to compute the $\alpha_t$ values.
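If you want to convince yourself that the schedule behaves as expected before plugging it in, a quick check along the following lines should do the trick. The snippet is purely illustrative and assumes the `Linear` class defined above, with `torch` already imported earlier in this guide:

```python
# Instantiate the custom schedule with the same values used earlier in this guide
schedule = Linear(1000, 0.9999, 0.98)
alpha = schedule.compute()
print(alpha.shape)          # torch.Size([1001]) -- one value for each t in {0, ..., T}
print(alpha[0], alpha[-1])  # roughly 0.9999 and 0.98 at the endpoints
```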
Now you can use your custom module in your `diffusion.Model` just like you did with the prebuilt ones! For more detailed guidance on extending each module type check out our [Custom Modules Tutorial](/modular-diffusion/guides/custom-modules). 197 | 198 | Another neat feature of Modular Diffusion is it provides an intuitive way to combine existing modules without having to create new ones. For instance, sometimes you'll want to train the model on a hybrid loss function that is a linear combination of two or more functions. In their paper, [Nichol & Dhariwal (2021)](https://arxiv.org/abs/2102.09672) introduced such a loss function, which is a linear combination of the simple loss function proposed by [Ho et al. (2020)](https://arxiv.org/abs/2006.11239) and the [variational lower bound (VLB)](https://en.wikipedia.org/wiki/Evidence_lower_bound): 199 | 200 | $$L_\text{hybrid}=L_\text{simple}+0.001 \cdot L_\text{vlb}$$ 201 | 202 | With Modular Diffusion, rather than creating a custom hybrid loss module, you can conveniently achieve this by combining the `Simple` and `VLB` modules: 203 | 204 | ```python 205 | from diffusion.loss import Simple, VLB 206 | 207 | loss = Simple(parameter="epsilon") + 0.001 * VLB() 208 | ``` 209 | 210 | Similarly, you can append post-processing layers to your denoiser network with the pipe operator, without the need to create a new `Net` module: 211 | 212 | ```python 213 | from diffusion.net import Transformer 214 | from torch.nn import Softmax 215 | 216 | net = Transformer(input=512) | Softmax(2) 217 | ``` 218 | -------------------------------------------------------------------------------- /docs/src/pages/guides/image-generation.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 1.3 3 | title: "Image Generation" 4 | index: false 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | *This page is under construction. Please check back later.* -------------------------------------------------------------------------------- /docs/src/pages/guides/text-generation.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 1.4 3 | title: "Text Generation" 4 | index: false 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | *This page is under construction. Please check back later.* -------------------------------------------------------------------------------- /docs/src/pages/modules/data-transform.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.3 3 | title: "Data Transform" 4 | index: true 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | In many Diffusion Models, the diffusion process unfolds within the **dataset space**. However, a growing number of algorithms, like [Stable Diffusion](https://arxiv.org/abs/2112.10752) project data onto a **latent space** before applying diffusion. Modular Diffusion includes an `Identity` transform to allow you to use your data as-is, but also ships with a collection of other data transforms. 10 | 11 | > Notation 12 | > 13 | > Throughout this page, we use $x$ rather than $x_0$ to denote the transformed data for increased readability. Any indexation to $x$ should be interpreted as accessing its individual elements. 14 | 15 | ## Identity transform 16 | 17 | Does not alter the input data. The transform is given by: 18 | 19 | - $x = w$ 20 | - $w = x$. 21 | 22 | ### Parameters 23 | 24 | - `w` -> Input tensor $w$. 25 | - `y` (default: `None`) -> Optional label tensor $y$. 
26 | - `batch` (default: `1`) -> Number of samples per training batch. 27 | - `shuffle` (default: `True`) -> Whether to shuffle the data before each epoch. 28 | 29 | ### Example 30 | 31 | ```python 32 | import torch 33 | from diffusion.data import Identity 34 | 35 | w = torch.tensor([[1, 2, 3]]) 36 | data = Identity(w) 37 | x = data.transform(next(data)) 38 | # x = tensor([[1, 2, 3]]) 39 | ``` 40 | 41 | ## One-hot vector transform 42 | 43 | Represents the input data as one-hot vectors. The transform is given by: 44 | 45 | - $x_{\dots ij} =\begin{cases} 1 & \text{if } j = w_{\dots i} \\0 & \text{otherwise}\end{cases}$ 46 | - $w_{\dots i} = \underset{\text{j}}{\text{argmax}}(x_{\dots ij})$. 47 | 48 | ### Parameters 49 | 50 | - `w` -> Input tensor $w$. 51 | - `y` (default: `None`) -> Optional label tensor $y$. 52 | - `k` -> Number of categories $k$. 53 | - `batch` (default: `1`) -> Number of samples per training batch. 54 | - `shuffle` (default: `True`) -> Whether to shuffle the data before each epoch. 55 | 56 | ### Example 57 | 58 | ```python 59 | import torch 60 | from diffusion.data import OneHot 61 | 62 | w = torch.tensor([[0, 2, 2]]) 63 | data = OneHot(w, k=3) 64 | x = data.transform(next(data)) 65 | # x = tensor([[[1, 0, 0], 66 | # [0, 0, 1], 67 | # [0, 0, 1]]]) 68 | ``` 69 | 70 | ## Embedding space transform 71 | 72 | Represents the input data in the embedding space. The embedding matrix is initialized with random values and **updated during training**. Let $\text{E} \in \mathbb{R}^{k \times d}$ be the embedding matrix, where $k$ is the number of categories and $d$ is the embedding dimension. Then the transform is defined as: 73 | 74 | - $x_{\dots ij} = \text{E}_{w_{\dots i}j}$ 75 | - $w_{\dots i} = \underset{\text{k}}{\text{argmin}}\left(\underset{\text{i, k}}{\text{cdist}}\left(x_{\dots ij}, \text{E}_{kj}\right)\right)$. 76 | 77 | ### Parameters 78 | 79 | - `w` -> Input tensor $w$. 80 | - `y` (default: `None`) -> Optional label tensor $y$. 81 | - `k` -> Number of categories $k$. 82 | - `d` -> Embedding dimension $d$. 83 | - `batch` (default: `1`) -> Number of samples per training batch. 84 | - `shuffle` (default: `True`) -> Whether to shuffle the data before each epoch. 85 | 86 | ### Example 87 | 88 | ```python 89 | import torch 90 | from diffusion.data import Embedding 91 | 92 | w = torch.tensor([[0, 2, 2]]) 93 | data = Embedding(w, k=3, d=5) 94 | x = data.transform(next(data)) 95 | # x = tensor([[[0.201, -0.415, 0.683, -0.782, 0.039], 96 | # [-0.509, 0.893, 0.102, -0.345, 0.623], 97 | # [-0.509, 0.893, 0.102, -0.345, 0.623]]]) 98 | ``` 99 | 100 | -------------------------------------------------------------------------------- /docs/src/pages/modules/denoising-network.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.7 3 | title: "Denoising Network" 4 | index: true 5 | visualizations: maybe 6 | --- 7 | 8 | # {frontmatter.title} 9 | 10 | The backbone of Diffusion Models is a denoising network, which is trained to gradually denoise data. While earlier works used a **U-Net** architecture, newer research has shown that **Transformers** can be used to achieve comparable or superior results. Modular Diffusion ships with both types of denoising network. Both are implemented in Pytorch and thinly wrapped in a `Net` module. 11 | 12 | > Future warning 13 | > 14 | > The current denoising network implementations are not necessarily the most efficient or the most effective and are bound to change in a future release. 
They do, however, provide a great starting point for experimentation. 15 | 16 | ## U-Net 17 | 18 | U-Net implementation adapted from the [The Annotated Diffusion Model](https://huggingface.co/blog/annotated-diffusion). It takes an input with shape `[b, c, h, w]` and returns an output with shape `[p, b, c, h, w]`. 19 | 20 | ### Parameters 21 | 22 | - `channels` -> Sequence of integers representing the number of channels in each layer of the U-Net. 23 | - `labels` (default `0`) -> Number of unique labels in $y$. 24 | - `parameters` (default `1`) -> Number of output parameters `p`. 25 | - `hidden` (default `256`) -> Hidden dimension. 26 | - `heads` (default `8`) -> Number of attention heads. 27 | - `groups` (default `16`) -> Number of groups in the group normalization layers. 28 | 29 | ### Example 30 | 31 | ```python 32 | from diffusion.net import UNet 33 | 34 | net = UNet(channels=(3, 64, 128, 256), labels=10) 35 | ``` 36 | 37 | ## Transformer 38 | 39 | Transformer implementation adapted from the [Peebles & Xie (2022) 40 | ](https://arxiv.org/abs/2212.09748) (adaptive layer norm block). It takes an input with shape `[b, l, e]` and returns an output with shape `[p, b, l, e]`. 41 | 42 | ### Parameters 43 | 44 | - `input` -> Input embedding dimension `e`. 45 | - `labels` (default `0`) -> Number of unique labels in $y$. 46 | - `parameters` (default `1`) -> Number of output parameters `p`. 47 | - `depth` (default `256`) -> Number of transformer blocks. 48 | - `width` (default `256`) -> Hidden dimension. 49 | - `heads` (default `8`) -> Number of attention heads. 50 | 51 | ### Example 52 | 53 | ```python 54 | from diffusion.net import Transformer 55 | 56 | net = Transformer(input=x.shape[2]) 57 | ``` 58 | 59 | -------------------------------------------------------------------------------- /docs/src/pages/modules/diffusion-model.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.1 3 | title: "Diffusion Model" 4 | index: true 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | In Modular Diffusion, the `Model` class is a high-level interface that allows you to easily design and train your own custom Diffusion Models. It acts essentially as a container for all the modules that make up a Diffusion Model. 10 | 11 | ### Parameters 12 | 13 | - `data` -> Data transform module. 14 | - `schedule` -> Noise schedule module. 15 | - `noise` -> Noise type module. 16 | - `net` -> Denoising network module. 17 | - `loss` -> Loss function module. 18 | - `guidance` (Default: `None`) -> Optional guidance module. 19 | - `optimizer` (Default: `partial(Adam, lr=1e-4)`) -> Pytorch optimizer constructor function. 20 | - `device` (Default: `"cpu"`) -> Device to train the model on. 21 | - `compile` (Default: `true`) -> Whether to compile the model with `torch.compile` for faster training. 
22 | 23 | ### Example 24 | ```python 25 | import diffusion 26 | from diffusion.data import Identity 27 | from diffusion.guidance import ClassifierFree 28 | from diffusion.loss import Simple 29 | from diffusion.net import UNet 30 | from diffusion.noise import Gaussian 31 | from diffusion.schedule import Cosine 32 | from torch.optim import AdamW 33 | from functools import partial 34 | 35 | model = diffusion.Model( 36 | data=Identity(x, y, batch=128, shuffle=True), 37 | schedule=Cosine(steps=1000), 38 | noise=Gaussian(parameter="epsilon", variance="fixed"), 39 | net=UNet(channels=(1, 64, 128, 256), labels=10), 40 | loss=Simple(parameter="epsilon"), 41 | guidance=ClassifierFree(dropout=0.1, strength=2), 42 | optimizer=partial(AdamW, lr=3e-4), 43 | device="cuda" if torch.cuda.is_available() else "cpu", 44 | ) 45 | ``` 46 | 47 | ## Train the model 48 | 49 | `Model.train` trains the model for a specified number of epochs. It **returns a generator** that yields the current loss when each epoch is finished, allowing the user to easily **validate the model between epochs** inside a `for` loop. 50 | 51 | ### Parameters 52 | 53 | - `epochs` (default: `1`) -> Number of epochs to train the model. 54 | - `progress` (default: `True`) -> Whether to display a progress bar for each epoch. 55 | 56 | ### Examples 57 | 58 | ```python 59 | # Train model without validation 60 | losses = [*model.train(epochs=100)] 61 | ``` 62 | 63 | ```python 64 | # Train model with validation 65 | for epoch, loss in enumerate(model.train(epochs=100)): 66 | if epoch % 10 == 0: 67 | # Validate your model here 68 | model.save("model.pt") 69 | ``` 70 | 71 | ## Sample from the model 72 | 73 | `Model.sample` samples from the model for a specified batch size and label tensor. It returns a tensor with shape `[t, b, ...]` where `t` is the number of time steps, `b` is the batch size, and `...` represents the shape of the data. This allows the user to **visualize the sampling process**. 74 | 75 | ### Parameters 76 | 77 | - `y` (default: `None`) -> Optional label tensor $y$ to condition sampling. 78 | - `batch` (default: `1`) -> Number of samples to generate. If `y` is not None, this is the number of samples per label. 79 | - `progress` (default: `True`) -> Whether to display a progress bar. 80 | 81 | ### Examples 82 | 83 | ```python 84 | # Save only final sampling results 85 | *_, z = model.sample(batch=10) 86 | ``` 87 | 88 | ```python 89 | # Save entire sampling process 90 | z = model.sample(batch=10) 91 | ``` 92 | 93 | ## Load the model 94 | 95 | `Model.load` loads the model's trainable weights from a file. The model should be initialized with **the same trainable modules it was initially trained with**. If a trainable module is replaced with a different module, the model **will not load correctly**. 96 | 97 | ### Parameters 98 | 99 | - `path` -> Path to the file containing the model's weights. 100 | 101 | ### Example 102 | 103 | ```python 104 | import diffusion 105 | from pathlib import Path 106 | 107 | model = diffusion.Model(...) 108 | if Path("model.pt").exists() 109 | model.load("model.pt") 110 | ``` 111 | 112 | ## Save the model 113 | 114 | `Model.save` saves the model's trainable weights to a file. 115 | 116 | ### Parameters 117 | 118 | - `path` -> Path to the file to save the model's weights to. 
119 | 120 | ### Example 121 | 122 | ```python 123 | model.save("model.pt") 124 | ``` -------------------------------------------------------------------------------- /docs/src/pages/modules/guidance.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.9 3 | title: "Guidance" 4 | index: true 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | In Diffusion Models, guidance mechanisms control how much importance the model gives to the conditioning information, at the cost of sample diversity. The two most prevalent forms of guidance are **Classifier Guidance** and **Classifier-Free Guidance**. As of right now, Modular Diffusion only ships with the latter, **but will support both in an upcoming release.** 10 | 11 | ## Classifier-free guidance 12 | 13 | Classifier-free guidance was introduced in [Ho & Salimans. (2022)](https://arxiv.org/abs/2207.12598) where it was found to produce higher fidelity samples in **conditional** Diffusion Models. It modifies the diffusion process as follows: 14 | 15 | - During **training**, a random subset of the batch labels are dropped, i.e., replaced with 0, before each epoch. 16 | - During **sampling**, predicted values $\hat{x}_\theta$ are computed according to $\hat{x}_\theta = (1 + s)\cdot\hat{x}_\theta(x_t|y) - s\cdot\hat{x}_\theta(x_t|0)$ 17 | 18 | where $s$ is a scalar parameter that controls the strength of the guidance signal. 19 | 20 | ### Parameters 21 | 22 | - `dropout` -> Percentage of labels dropped during training. 23 | - `strength` -> Strength of the guidance signal $s$. 24 | 25 | ### Example 26 | 27 | ```python 28 | from diffusion.guidance import ClassifierFree 29 | 30 | guidance = ClassifierFree(dropout=0.1, strength=2) 31 | ``` 32 | 33 | ## Classifier guidance 34 | 35 | *This guidance module is currently in development.* 36 | -------------------------------------------------------------------------------- /docs/src/pages/modules/loss-function.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.8 3 | title: "Loss Function" 4 | index: true 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | The loss function of the denoising network seems to play a crucial role in the quality of the samples generated by Diffusion Models. Modular Diffusion ships with the reoccurring $L_\text{simple}$ and $L_\text{vlb}$ functions, as well as a `Lambda` utility to build your own custom loss function. 10 | 11 | > Hybrid losses 12 | > 13 | > To create a hybrid loss, simply add different loss modules together with a weight. For instance, to create a loss function that is a combination of $L_\text{simple}$ and $L_\text{vlb}$, you could write `loss = Simple() + 0.001 * VLB()`. 14 | 15 | ## Training batch 16 | 17 | While not a loss module, the `Batch` object is a fundamental component of Modular Diffusion. It is used to store the data that is fed to the loss module during training. When creating custom loss modules, it is important to know the names used to refer to the different tensors stored in the `Batch` object, listed below. 18 | 19 | ### Properties 20 | 21 | - `w` -> Initial data tensor $w$. 22 | - `x` -> Data tensor after transform $x_0$. 23 | - `y` -> Label tensor $y$. 24 | - `t` -> Time step tensor $t$. 25 | - `epsilon` -> Noise tensor $\epsilon$. May be `None` for certain noise types. 26 | - `z` -> Latent tensor $x_t$. 27 | - `hat` -> Predicted tensor $\hat{x}_\theta$, $\hat{\epsilon}_\theta$, or other(s) depending on the parametrization. 
28 | - `q` -> Posterior distribution $q(x_{t-1}|x_t, x_0)$. 29 | - `p` -> Approximate posterior distribution $p_\theta(x_{t-1} | x_t)$. 30 | 31 | ## Lambda function 32 | 33 | Custom loss module that is defined using a lambda function and parametrized with a distribution. It is meant to be used as shorthand for writing a custom loss function class. 34 | 35 | ### Parameters 36 | 37 | - `function` -> Callable which receives a `Batch` object and returns a `Tensor` containing the loss value. 38 | 39 | ### Example 40 | 41 | ```python 42 | from diffusion.loss import Lambda 43 | from diffusion.distribution import Normal as N 44 | 45 | loss = Lambda[N](lambda b: ((b.q.mu - b.p.mu)**2).mean()) 46 | ``` 47 | 48 | > Type checking 49 | > 50 | > If you are using a type checker or want useful intellisense, you will need to explicitly parametrize the `Lambda` class with a `Distribution` type as seen in the example. 51 | 52 | ## Simple loss function 53 | 54 | Simple MSE loss introduced by [Ho et al. (2020)](https://arxiv.org/abs/2006.11239) in the context of Diffusion Models. Depending on the parametrization, it is defined as: 55 | 56 | - $L_\text{simple}=\mathbb{E}\left[\lvert\lvert x-\hat{x}_\theta\rvert\rvert^2\right]$ 57 | - $L_\text{simple}=\mathbb{E}\left[\lvert\lvert\epsilon-\hat{\epsilon}_\theta\rvert\rvert^2\right]$. 58 | 59 | ### Parameters 60 | 61 | - `parameter` (default `"x"`) -> Parameter to be learned and used to compute the loss. Either `"x"` ($\hat{x}_\theta$) or `"epsilon"` ($\hat{\epsilon}_\theta$). 62 | - `index` (default `0`) -> Index of the `hat` tensor which corresponds to the selected `parameter`. 63 | 64 | > Parametrization 65 | > 66 | > If you have the option, always remember to select the same parameter both in your model's `Noise` and `Loss` objects. 67 | 68 | ### Example 69 | 70 | ```python 71 | from diffusion.loss import Simple 72 | 73 | loss = Simple(parameter="epsilon") 74 | ``` 75 | 76 | ## Variational lower bound 77 | 78 | In the context of Diffusion Models, the variational lower bound (VLB) of $\log p(x_0)$ is given by: 79 | 80 | $$\begin{aligned}L_\text{vlb} & = \mathbb{E}_{q(x_{1}|x_0)}\left[\log p_{\theta}(x_0|x_1)\right] \\ & - \sum_{t=2}^{T} \mathbb{E}_{q(x_{t}|x_0)}\left[D_{KL}(q(x_{t-1}|x_t, x_0)||p_{\theta}(x_{t-1}|x_t))\right] \\ & - D_{KL}(q(x_T|x_0)||p(x_T))\text{,}\end{aligned}$$ 81 | 82 | where $D_{KL}(q(x_T|x_0)||p(x_T))$ is considered to be equal to 0 under standard assumptions. 83 | 84 | ### Parameters 85 | 86 | *This module has no parameters.* 87 | 88 | ### Example 89 | 90 | ```python 91 | from diffusion.loss import VLB 92 | 93 | loss = VLB() 94 | ``` 95 | -------------------------------------------------------------------------------- /docs/src/pages/modules/noise-schedule.mdx: -------------------------------------------------------------------------------- 1 | --- 2 | id: 2.4 3 | title: "Noise Schedule" 4 | index: true 5 | --- 6 | 7 | # {frontmatter.title} 8 | 9 | In Diffusion Models, the noise schedule dictates how much noise is added to the data at each time step. The noise schedule is typically defined as a function $\alpha_t$ that maps a time step $t$ into a value $\alpha_t \in [0, 1]$. Modular Diffusion comes with a growing set of prebuilt noise schedules. 10 | 11 | ## Constant schedule 12 | 13 | Constant noise schedule given by $\alpha_t = k$. 14 | 15 | ### Parameters 16 | 17 | - `steps` -> Number of time steps $T$. 18 | - `value` -> Constant value $k$. 
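Because a constant schedule multiplies the retained signal by the same factor $k$ at every step, the cumulative value is simply $\bar{\alpha}_t = k^t$ (following the usual definition $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ used on the Noise Type page). A quick, library-independent check of how fast this decays for the values used in the example below:

```python
k = 0.995
for t in (100, 500, 1000):
    print(t, round(k**t, 3))
# 100 0.606
# 500 0.082
# 1000 0.007
```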
19 | 20 | ### Example 21 | 22 | ```python 23 | from diffusion.schedule import Constant 24 | 25 | schedule = Constant(1000, 0.995) 26 | ``` 27 | 28 | ### Visualization 29 | 30 | Applying `Gaussian` noise to an image using the `Constant` schedule with $T=1000$ and $k=0.995$ in equally spaced snapshots: 31 | 32 | ![Image of a dog getting noisier at a constant rate.](/modular-diffusion/images/modules/noise-schedule/constant.png) 33 | 34 | ## Linear schedule 35 | 36 | Linear noise schedule introduced in [Ho et al. (2020)](https://arxiv.org/abs/2006.11239) computed by linearly interpolating from $\alpha_0$ to $\alpha_T$. 37 | 38 | ### Parameters 39 | 40 | - `steps` -> Number of time steps $T$. 41 | - `start` -> Start value $\alpha_0$. 42 | - `end` -> End value $\alpha_T$. 43 | 44 | ### Example 45 | 46 | ```python 47 | from diffusion.schedule import Linear 48 | 49 | schedule = Linear(1000, 0.9999, 0.98) 50 | ``` 51 | 52 | ### Visualization 53 | 54 | Applying `Gaussian` noise to an image using the `Linear` schedule with $T=1000$, $\alpha_0=0.9999$ and $\alpha_T=0.98$ in equally spaced snapshots: 55 | 56 | ![Image of a dog getting noisier at a linear rate.](/modular-diffusion/images/modules/noise-schedule/linear.png) 57 | 58 | ## Cosine schedule 59 | 60 | Cosine noise schedule introduced in [Nichol et al. (2021)](https://arxiv.org/abs/2102.12092) which offers a more gradual noising process relative to the linear schedule. It is defined as $\alpha_t = \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$, where: 61 | 62 | - $\bar{\alpha}_t=\frac{f(t)}{f(0)}$ 63 | - $f(t) = \cos(\frac{t/T+s}{1+s} \cdot \frac{\pi}{2})^e$ 64 | 65 | ### Parameters 66 | 67 | - `steps` -> Number of time steps $T$. 68 | - `offset` (default: `8e-3`) -> Offset $s$. 69 | - `exponent` (default: `2`) -> Exponent $e$. 70 | 71 | ### Example 72 | 73 | ```python 74 | from diffusion.schedule import Cosine 75 | 76 | schedule = Cosine(1000) 77 | ``` 78 | 79 | ### Visualization 80 | 81 | Applying `Gaussian` noise to an image using the `Cosine` schedule with $T=1000$, $s=8e-3$ and $e=2$ in equally spaced snapshots: 82 | 83 | ![Image of a dog getting noisier at a cosine rate.](/modular-diffusion/images/modules/noise-schedule/cosine.png) 84 | 85 | ## Square root schedule 86 | 87 | Square root noise schedule introduced in [Li et al. (2022)](https://arxiv.org/abs/2110.03895). It is defined as $\alpha_t = \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}$, where $\bar{\alpha}_t=1-\sqrt{t/T+s}$. 88 | 89 | ### Parameters 90 | 91 | - `steps` -> Number of time steps $T$. 92 | - `offset` (default: `8e-3`) -> Offset $s$. 
93 | 
94 | ### Example
95 | 
96 | ```python
97 | from diffusion.schedule import Sqrt
98 | 
99 | schedule = Sqrt(1000)
100 | ```
101 | 
102 | ### Visualization
103 | 
104 | Applying `Gaussian` noise to an image using the `Sqrt` schedule with $T=1000$ and $s=8e-3$ in equally spaced snapshots:
105 | 
106 | ![Image of a dog getting noisier at a sqrt rate.](/modular-diffusion/images/modules/noise-schedule/sqrt.png)
107 | 
--------------------------------------------------------------------------------
/docs/src/pages/modules/noise-type.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | id: 2.6
3 | title: "Noise Type"
4 | index: true
5 | ---
6 | 
7 | # {frontmatter.title}
8 | 
9 | In Diffusion Models, a noise type defines a specific parametrization of the stationary, prior, posterior, and approximate posterior distributions, $q(x_{T})$, $q(x_{t}|x_{0})$, $q(x_{t-1}|x_{t},x_{0})$, and $p_\theta(x_{t-1} | x_t)$, respectively. Modular Diffusion includes the standard `Gaussian` noise parametrization, as well as a few more noise types.
10 | 
11 | ## Gaussian noise
12 | 
13 | Gaussian noise model introduced in [Ho et al. (2020)](https://arxiv.org/abs/2006.11239), for which the diffusion process is defined as:
14 | 
15 | - $q(x_{T})=\mathcal{N}(x_T; 0, \text{I})$
16 | - $q(x_{t}|x_{0})=\mathcal{N}(x_{t};\sqrt{\bar{\alpha}_{t}}x_{0},(1 - \bar{\alpha}_{t})\text{I})$
17 | - $q(x_{t-1}|x_{t},x_{0})=\mathcal{N}(x_{t-1};\frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})x_{t} + \sqrt{\bar\alpha_{t-1}}(1-\alpha_t)x_0}{1 -\bar\alpha_{t}},\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 -\bar\alpha_{t}}\text{I})$
18 | - $p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1};\hat{\mu}_\theta,\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 -\bar\alpha_{t}}\text{I})$,
19 | 
20 | where, depending on the parametrization:
21 | 
22 | - $\hat{\mu}_\theta = \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})x_{t} + \sqrt{\bar\alpha_{t-1}}(1-\alpha_t)\hat{x}_\theta}{1 -\bar\alpha_{t}}$
23 | - $\hat{\mu}_\theta = \frac{1}{\sqrt{\alpha_t}}x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}\sqrt{\alpha_t}}\hat{\epsilon}_\theta$.
24 | 
25 | ### Parameters
26 | 
27 | - `parameter` (default `"x"`) -> Parameter to be learned and used to compute $\hat{\mu}_\theta$. If `"x"` ($\hat{x}_\theta$) or `"epsilon"` ($\hat{\epsilon}_\theta$) are chosen, $\hat{\mu}_\theta$ is computed using one of the formulas above. Selecting `"mu"` means that $\hat{\mu}_\theta$ is predicted directly. Typically, authors find that learning $\hat{\epsilon}_\theta$ leads to better results.
28 | - `variance` (default `"fixed"`) -> If `"fixed"`, the variance of $p_\theta(x_{t-1} | x_t)$ is fixed to $\frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 -\bar\alpha_{t}}\text{I}$. If `"learned"`, the variance is learned as a parameter of the model.
29 | 
30 | > Parametrization
31 | >
32 | > If you have the option, always remember to select the same parameter both in your model's `Noise` and `Loss` objects.
33 | 
34 | ### Example
35 | 
36 | ```python
37 | from diffusion.noise import Gaussian
38 | 
39 | noise = Gaussian(parameter="epsilon", variance="fixed")
40 | ```
41 | 
42 | ### Visualization
43 | 
44 | Applying `Gaussian` noise to an image using the `Cosine` schedule with $T=1000$, $s=8e-3$ and $e=2$ in equally spaced snapshots:
45 | 
46 | ![Image of a dog gradually turning noisy.](/modular-diffusion/images/modules/noise-type/gaussian.png)
47 | 
48 | ## Uniform categorical noise
49 | 
50 | Uniform categorical noise model introduced in [Austin et al.
(2021)](https://arxiv.org/abs/2107.03006). In each time step, each token either stays the same or transitions to a different state. The noise type is defined by:
51 | 
52 | - $q(x_T) = \mathrm{Cat}(x_T; \frac{\mathbb{1}\mathbb{1}^T}{k})$
53 | - $q(x_t | x_0) = \mathrm{Cat}(x_t; x_0\overline{Q}_t)$
54 | - $q(x_{t-1}|x_t, x_0) = \mathrm{Cat}\left(x_{t-1}; \frac{x_t Q_t^{\top} \odot x_0 \overline{Q}_{t-1}}{x_0 \overline{Q}_t x_t^\top}\right)$
55 | - $p_\theta(x_{t-1} | x_t) = \mathrm{Cat}\left(x_{t-1}; \frac{x_t Q_t^{\top} \odot \hat{x}_\theta \overline{Q}_{t-1}}{\hat{x}_\theta \overline{Q}_t x_t^\top}\right)$,
56 | 
57 | where:
58 | 
59 | - $\mathbb{1}$ is a column vector of ones of length $k$.
60 | - $Q_t = \alpha_t \text{I} + (1 - \alpha_t) \frac{\mathbb{1}\mathbb{1}^T}{k}$
61 | - $\overline{Q}_{t} = \bar{\alpha}_t \text{I} + (1 - \bar{\alpha}_t) \frac{\mathbb{1}\mathbb{1}^T}{k}$
62 | 
63 | > One-hot representation
64 | >
65 | > The `Uniform` noise type operates on one-hot vectors. To use it, you must use the `OneHot` data transform.
66 | 
67 | ### Parameters
68 | 
69 | - `k` -> Number of categories $k$.
70 | 
71 | ### Example
72 | 
73 | ```python
74 | from diffusion.noise import Uniform
75 | 
76 | noise = Uniform(k=26)
77 | ```
78 | 
79 | ### Visualization
80 | 
81 | Applying `Uniform` noise to an image with $k=255$ using the `Cosine` schedule with $T=1000$, $s=8e-3$ and $e=2$ in equally spaced snapshots:
82 | 
83 | ![Image of a dog gradually turning noisy.](/modular-diffusion/images/modules/noise-type/uniform.png)
84 | 
85 | ## Absorbing categorical noise
86 | 
87 | Absorbing categorical noise model introduced in [Austin et al. (2021)](https://arxiv.org/abs/2107.03006). In each time step, each token either stays the same or transitions to an absorbing state. The noise type is defined by:
88 | 
89 | - $q(x_T) = \mathrm{Cat}(x_T; \mathbb{1}e_m^T)$
90 | - $q(x_t | x_0) = \mathrm{Cat}(x_t; x_0\overline{Q}_t)$
91 | - $q(x_{t-1}|x_t, x_0) = \mathrm{Cat}\left(x_{t-1}; \frac{x_t Q_t^{\top} \odot x_0 \overline{Q}_{t-1}}{x_0 \overline{Q}_t x_t^\top}\right)$
92 | - $p_\theta(x_{t-1} | x_t) = \mathrm{Cat}\left(x_{t-1}; \frac{x_t Q_t^{\top} \odot \hat{x}_\theta \overline{Q}_{t-1}}{\hat{x}_\theta \overline{Q}_t x_t^\top}\right)$,
93 | 
94 | where:
95 | 
96 | - $\mathbb{1}$ is a column vector of ones of length $k$.
97 | - $e_m$ is a vector with a 1 on
98 | the absorbing state $m$ and 0 elsewhere.
99 | - $Q_t = \alpha_t \text{I} + (1 - \alpha_t) \mathbb{1}e_m^T$
100 | - $\overline{Q}_{t} = \bar{\alpha}_t \text{I} + (1 - \bar{\alpha}_t) \mathbb{1}e_m^T$
101 | 
102 | > One-hot representation
103 | >
104 | > The `Absorbing` noise type operates on one-hot vectors. To use it, you must use the `OneHot` data transform.
105 | 
106 | ### Parameters
107 | 
108 | - `k` -> Number of categories $k$.
109 | - `m` -> Absorbing state $m$.
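To make the transition matrix concrete, here is a small, library-independent sketch that builds $Q_t$ for toy values and checks that every row is a valid categorical distribution:

```python
import torch

k, m, alpha = 4, 3, 0.9  # toy values: 4 categories, absorbing state 3
one = torch.ones(k, 1)
e_m = torch.zeros(k, 1)
e_m[m] = 1
Q = alpha * torch.eye(k) + (1 - alpha) * one @ e_m.T
# each row keeps mass alpha on its own state and moves 1 - alpha to state m;
# row m stays entirely on m (absorbing), and every row sums to 1
assert torch.allclose(Q.sum(dim=1), torch.ones(k))
```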
110 | 
111 | ### Example
112 | 
113 | ```python
114 | from diffusion.noise import Absorbing
115 | 
116 | noise = Absorbing(k=255, m=128)
117 | ```
118 | 
119 | ### Visualization
120 | 
121 | Applying `Absorbing` noise to an image with $k=255$ and $m=128$ using the `Cosine` schedule with $T=1000$, $s=8e-3$ and $e=2$ in equally spaced snapshots:
122 | 
123 | ![Image of a dog gradually turning gray.](/modular-diffusion/images/modules/noise-type/absorbing.png)
124 | 
--------------------------------------------------------------------------------
/docs/src/pages/modules/probability-distribution.mdx:
--------------------------------------------------------------------------------
1 | ---
2 | id: 2.5
3 | title: "Probability Distribution"
4 | index: true
5 | ---
6 | 
7 | # {frontmatter.title}
8 | 
9 | In Diffusion Models, the choice of a probability distribution plays a pivotal role in modeling the noise that guides transitions between time steps. While the `Distribution` type is not directly used to parametrize the `Model` class, it is used to create custom `Noise` and `Loss` modules. Modular Diffusion provides you with a set of distribution classes you can use to create your own modules.
10 | 
11 | > Parameter shapes
12 | >
13 | > Distribution parameters are represented as tensors with the same size as a batch. This essentially means that a `Distribution` object functions as a collection of distributions, where each individual element in a batch corresponds to a unique distribution. For instance, in the case of a standard DDPM, each pixel in a batch of images is associated with its own `mu` and `sigma` values.
14 | 
15 | ## Normal distribution
16 | 
17 | Continuous probability distribution that is ubiquitously used in Diffusion Models. It has the following density function:
18 | 
19 | $$f(x) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$
20 | 
21 | Sampling from a normal distribution is denoted $x \sim \mathcal{N}(\mu, \sigma^2)$ and is equivalent to sampling from a standard normal distribution ($\mu = 0$ and $\sigma = 1$), scaling the result by $\sigma$, and shifting it by $\mu$:
22 | 
23 | - $\epsilon \sim \mathcal{N}(0, \text{I})$
24 | - $x = \mu + \sigma \epsilon$
25 | 
26 | ### Parameters
27 | 
28 | - `mu: Tensor` -> Mean tensor $\mu$.
29 | - `sigma: Tensor` -> Standard deviation tensor $\sigma$. Must have the same shape as `mu`.
30 | 
31 | > Parametrization
32 | >
33 | > Please note that the `sigma` parameter does not correspond to the variance $\sigma^2$, but to the standard deviation $\sigma$.
34 | 
35 | ### Example
36 | 
37 | ```python
38 | import torch
39 | from diffusion.distribution import Normal as N
40 | 
41 | distribution = N(torch.zeros(3), torch.full((3,), 2))
42 | x, epsilon = distribution.sample()
43 | # x = tensor([ 1.1053, 1.9027, -0.2554])
44 | # epsilon = tensor([ 0.5527, 0.9514, -0.1277])
45 | ```
46 | 
47 | ## Categorical distribution
48 | 
49 | Discrete probability distribution that separately specifies the probability of each one of $k$ possible categories in a vector $p$. Sampling from a categorical distribution is denoted $x \sim \text{Cat}(p)$.
50 | 
51 | ### Parameters
52 | 
53 | - `p: Tensor` -> Probability tensor $p$. All elements must be non-negative and sum to 1 in the last dimension.
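As an illustration of how these classes plug into the formulas on the Noise Type page, the uniform-noise marginal $q(x_t|x_0) = \text{Cat}(x_t; x_0\overline{Q}_t)$ amounts to mixing the one-hot data with a uniform distribution. The sketch below uses only the constructor and `sample` method documented on this page, with toy values; basic usage is shown in the example that follows.

```python
import torch
from diffusion.distribution import Categorical as Cat

k, alpha_bar = 3, 0.7  # toy values
x0 = torch.tensor([[0., 1., 0.], [1., 0., 0.]])  # one-hot batch
p = alpha_bar * x0 + (1 - alpha_bar) / k  # same as x0 @ (alpha_bar * I + (1 - alpha_bar) * ones(k, k) / k)
xt, _ = Cat(p).sample()  # the second return value is None for categorical distributions
```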
54 | 55 | ### Example 56 | 57 | ```python 58 | import torch 59 | from diffusion.distribution import Categorical as Cat 60 | 61 | distribution = Cat(torch.tensor([[.1, .3, .6], [0, 0, 1]])) 62 | x, _ = distribution.sample() 63 | # x = tensor([[0., 1., 0.], [0., 0., 1.]]) 64 | ``` 65 | 66 | > Noise tensor 67 | > 68 | > The categorical distribution returns `None` in place of a noise tensor $\epsilon$, as it would have no meaningful interpretation. Therefore, you must ignore the second return value when sampling. 69 | -------------------------------------------------------------------------------- /docs/src/plugins/remark-layout.mjs: -------------------------------------------------------------------------------- 1 | export default () => { 2 | return (tree, file) => { 3 | file.data.astro.frontmatter.layout = "../../layouts/Layout.astro"; 4 | for (const node of tree.children) { 5 | if (node.type === "paragraph" && node.children?.length > 1) { 6 | node.children.push({ type: "mdxJsxFlowElement", name: "span" }); 7 | } 8 | } 9 | }; 10 | }; 11 | -------------------------------------------------------------------------------- /docs/src/styles/fonts.css: -------------------------------------------------------------------------------- 1 | @font-face { 2 | font-family: 'Inter'; 3 | src: url('/modular-diffusion/fonts/Inter/Inter-Thin.ttf') format('truetype'); 4 | font-weight: 100; 5 | font-style: normal; 6 | } 7 | 8 | @font-face { 9 | font-family: 'Inter'; 10 | src: url('/modular-diffusion/fonts/Inter/Inter-ExtraLight.ttf') format('truetype'); 11 | font-weight: 200; 12 | font-style: normal; 13 | } 14 | 15 | @font-face { 16 | font-family: 'Inter'; 17 | src: url('/modular-diffusion/fonts/Inter/Inter-Light.ttf') format('truetype'); 18 | font-weight: 300; 19 | font-style: normal; 20 | } 21 | 22 | @font-face { 23 | font-family: 'Inter'; 24 | src: url('/modular-diffusion/fonts/Inter/Inter-Regular.ttf') format('truetype'); 25 | font-weight: 400; 26 | font-style: normal; 27 | } 28 | 29 | @font-face { 30 | font-family: 'Inter'; 31 | src: url('/modular-diffusion/fonts/Inter/Inter-Medium.ttf') format('truetype'); 32 | font-weight: 500; 33 | font-style: normal; 34 | } 35 | 36 | @font-face { 37 | font-family: 'Inter'; 38 | src: url('/modular-diffusion/fonts/Inter/Inter-SemiBold.ttf') format('truetype'); 39 | font-weight: 600; 40 | font-style: normal; 41 | } 42 | 43 | @font-face { 44 | font-family: 'Inter'; 45 | src: url('/modular-diffusion/fonts/Inter/Inter-Bold.ttf') format('truetype'); 46 | font-weight: 700; 47 | font-style: normal; 48 | } 49 | 50 | @font-face { 51 | font-family: 'Inter'; 52 | src: url('/modular-diffusion/fonts/Inter/Inter-ExtraBold.ttf') format('truetype'); 53 | font-weight: 800; 54 | font-style: normal; 55 | } 56 | 57 | @font-face { 58 | font-family: 'Inter'; 59 | src: url('/modular-diffusion/fonts/Inter/Inter-Black.ttf') format('truetype'); 60 | font-weight: 900; 61 | font-style: normal; 62 | } 63 | 64 | @font-face { 65 | font-family: 'FiraCode'; 66 | src: url('/modular-diffusion/fonts/FiraCode/FiraCode-Light.ttf') format('truetype'); 67 | font-weight: 300; 68 | font-style: normal; 69 | } 70 | 71 | @font-face { 72 | font-family: 'FiraCode'; 73 | src: url('/modular-diffusion/fonts/FiraCode/FiraCode-Regular.ttf') format('truetype'); 74 | font-weight: 400; 75 | font-style: normal; 76 | } 77 | 78 | @font-face { 79 | font-family: 'FiraCode'; 80 | src: url('/modular-diffusion/fonts/FiraCode/FiraCode-Medium.ttf') format('truetype'); 81 | font-weight: 500; 82 | font-style: normal; 83 | } 84 | 85 | 
@font-face { 86 | font-family: 'FiraCode'; 87 | src: url('/modular-diffusion/fonts/FiraCode/FiraCode-SemiBold.ttf') format('truetype'); 88 | font-weight: 600; 89 | font-style: normal; 90 | } 91 | 92 | @font-face { 93 | font-family: 'FiraCode'; 94 | src: url('/modular-diffusion/fonts/FiraCode/FiraCode-Bold.ttf') format('truetype'); 95 | font-weight: 700; 96 | font-style: normal; 97 | } 98 | -------------------------------------------------------------------------------- /docs/tailwind.config.cjs: -------------------------------------------------------------------------------- 1 | const defaultTheme = require("tailwindcss/defaultTheme"); 2 | 3 | /** @type {import('tailwindcss').Config} */ 4 | module.exports = { 5 | content: ["./src/**/*.{astro,html,js,jsx,md,mdx,svelte,ts,tsx,vue}"], 6 | theme: { 7 | extend: { 8 | fontFamily: { 9 | sans: ["Inter", ...defaultTheme.fontFamily.sans], 10 | mono: ["FiraCode", ...defaultTheme.fontFamily.mono], 11 | }, 12 | }, 13 | }, 14 | plugins: [require('tailwindcss-opentype')], 15 | }; 16 | -------------------------------------------------------------------------------- /docs/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": "astro/tsconfigs/strict" 3 | } -------------------------------------------------------------------------------- /examples/conditional-diffusion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import torch 5 | from einops import rearrange 6 | from torchvision.datasets import MNIST 7 | from torchvision.transforms import ToTensor 8 | from torchvision.utils import save_image 9 | 10 | sys.path.append(".") 11 | 12 | import diffusion 13 | from diffusion.data import Identity 14 | from diffusion.guidance import ClassifierFree 15 | from diffusion.loss import Simple 16 | from diffusion.net import UNet 17 | from diffusion.noise import Gaussian 18 | from diffusion.schedule import Cosine 19 | 20 | file = Path(__file__) 21 | input = file.parent / "data/in" 22 | output = file.parent / "data/out" / file.stem 23 | output.mkdir(parents=True, exist_ok=True) 24 | torch.set_float32_matmul_precision("high") 25 | torch.set_grad_enabled(False) 26 | 27 | x, y = zip(*MNIST(str(input), transform=ToTensor(), download=True)) 28 | x, y = torch.stack(x) * 2 - 1, torch.tensor(y) + 1 29 | 30 | model = diffusion.Model( 31 | data=Identity(x, y, batch=128, shuffle=True), 32 | schedule=Cosine(steps=1000), 33 | noise=Gaussian(parameter="epsilon", variance="fixed"), 34 | net=UNet(channels=(1, 64, 128, 256), labels=10), 35 | guidance=ClassifierFree(dropout=0.1, strength=2), 36 | loss=Simple(parameter="epsilon"), 37 | device="cuda" if torch.cuda.is_available() else "cpu", 38 | ) 39 | 40 | if (output / "model.pt").exists(): 41 | model.load(output / "model.pt") 42 | epoch = sum(1 for _ in output.glob("[0-9]*")) 43 | 44 | for epoch, loss in enumerate(model.train(epochs=100), 1): 45 | z = model.sample(torch.arange(1, 11)) 46 | z = z[torch.linspace(0, z.shape[0] - 1, 10).int()] 47 | z = rearrange(z, "t b c h w -> c (b h) (t w)") 48 | z = (z + 1) / 2 49 | save_image(z, output / f"{epoch}-{loss:.2e}.png") 50 | model.save(output / "model.pt") -------------------------------------------------------------------------------- /examples/data/representative/in/afhq/flickr_dog_000083.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/afhq/flickr_dog_000083.jpg -------------------------------------------------------------------------------- /examples/data/representative/in/afhq/flickr_dog_001159.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/afhq/flickr_dog_001159.jpg -------------------------------------------------------------------------------- /examples/data/representative/in/afhq/pixabay_dog_000802.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/afhq/pixabay_dog_000802.jpg -------------------------------------------------------------------------------- /examples/data/representative/in/afhq/pixabay_dog_003974.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/afhq/pixabay_dog_003974.jpg -------------------------------------------------------------------------------- /examples/data/representative/in/afhq/pixabay_dog_004034.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/afhq/pixabay_dog_004034.jpg -------------------------------------------------------------------------------- /examples/data/representative/in/e2e/0.txt: -------------------------------------------------------------------------------- 1 | The Rice Boat, in the riverside area, near Express by Holiday Inn, has English food, is kids friendly, has a high customer rating, and has a price range between 20 and 25 pounds. -------------------------------------------------------------------------------- /examples/data/representative/in/e2e/1.txt: -------------------------------------------------------------------------------- 1 | The Phoenix offers moderately priced fast food in the centre of the city. It has received 3 out of 5 customer rating. -------------------------------------------------------------------------------- /examples/data/representative/in/e2e/2.txt: -------------------------------------------------------------------------------- 1 | For a high-end coffee shop with high ratings in riverside, you should check out The Vaults near Café Brazil. -------------------------------------------------------------------------------- /examples/data/representative/in/e2e/3.txt: -------------------------------------------------------------------------------- 1 | A cheap pub The Plough is located near Café Rouge. It is not family friendly. -------------------------------------------------------------------------------- /examples/data/representative/in/e2e/4.txt: -------------------------------------------------------------------------------- 1 | The Cambridge Blue, located near the Café Brazil, is a pub with food under £20. 
-------------------------------------------------------------------------------- /examples/data/representative/in/mnist/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/mnist/1.png -------------------------------------------------------------------------------- /examples/data/representative/in/mnist/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/mnist/2.png -------------------------------------------------------------------------------- /examples/data/representative/in/mnist/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/mnist/3.png -------------------------------------------------------------------------------- /examples/data/representative/in/mnist/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/mnist/6.png -------------------------------------------------------------------------------- /examples/data/representative/in/mnist/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/in/mnist/7.png -------------------------------------------------------------------------------- /examples/data/representative/out/conditional-diffusion/1-7.31e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/conditional-diffusion/1-7.31e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/conditional-diffusion/2-3.91e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/conditional-diffusion/2-3.91e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/conditional-diffusion/20-3.56e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/conditional-diffusion/20-3.56e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/conditional-diffusion/4-4.54e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/conditional-diffusion/4-4.54e-02.png -------------------------------------------------------------------------------- 
/examples/data/representative/out/conditional-diffusion/7-4.03e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/conditional-diffusion/7-4.03e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/embedding-diffusion/293 (4.36e-03).txt: -------------------------------------------------------------------------------- 1 | imate approval areui heart beginare spacehoenix bite calm amenitiesdle better onk bathroomgs Fitzbillies go Exp overpriced Moderateess lack sho what fruit Look payrecoruits 2 | imed drinks.uitearea far atrociousact spen calledance GiraffeRruitsare whetended mark bus steak Fitzbillies2 fruit north tag Fitzbillies Ha guestcompan joint sit burger 3 | ick drinksarea w25£ riverult tagra called satisfying thing guaranteeutare guestlessibi Consumer Ah f50.rackeui worse Typically guest guest visith Euro 4 | treatriesll badvery mak Fitzbillies tag views toast break pleas trip 5,. misstanding calm Consumer w fruit everyone medi ratingle affordabl Name Nameh based 5 | There is a roadvery mak price pay spen toast plateg friend parties.GBPless calmr w grubrackeloselestance Si rank average 6 | There is a Specializvery mak price payvailable called 'ended alternative city.GBPless calm upward wguite averageGBP Si 7 | There is a Specializvery super price pay enjoyed called The Ric alternative Boat.GBP desserts wwelcoming Star 8 | There is a verycond family place located called The Rice Boat. mark Name 9 | There is a very finest family place located called The Rice Boat. 10 | There is a very affordable family place located called The Rice Boat. -------------------------------------------------------------------------------- /examples/data/representative/out/embedding-diffusion/490 (2.41e-02).txt: -------------------------------------------------------------------------------- 1 | Riverside eaten bitesnownfini problem exist Centrmazing mea health Sitstanding pro £20-2 upward shrimpeaturricing atmosphereui group Ra scorepe guestunchuiuntesstance 2 | eepmazingmb cl assort biteinally couplesgains couples Chinesekid more des stayriverside specialty opinion refreshment fries College leave Overaverage guest class pricesies f guest parton 3 | aw anymbienceest Japanese any payertain include wheries moreruits fruits Recent chips Familiesstanding orientalunt guest chain In noctionies f guest kitchen 4 | class dr Innunt £ landmark winery near baseduitenearankrban entireracke road chips reason gain guest heart50. comensalesnot guest ambiancectionwelcoming f pasta du 5 | famibi In sta pasta landmark gener nearfamily providenearended eveningunt delivers entire chips than gain water50.Tnotnot puwelcoming f £ 6 | Bibi In town Sushi, located near winery provide plateank earthand food entiretmospherere grub Offer regardedrian won f50. 7 | Bibimbtaking House, located near Cla provide plate, beatand food for less than £20 heart noodles50. 
8 | Bibimb guest House, located near Clare fruit, beat French food for less than £20 9 | Bibimbap House, located near Clare Hall, serves French food for less than £20h 10 | Bibimbap House, located near Clare Hall, serves French food for less than £20 -------------------------------------------------------------------------------- /examples/data/representative/out/embedding-diffusion/5 (1.63e-01).txt: -------------------------------------------------------------------------------- 1 | gra pounds upwardancespir take plate noodles du focuse fries suc 20.00 experienclong partiesunfriendly soluiet Bells pasta oriented du average sc standards20, attract rankwelcoming Punternjoy 2 | Lethoppingrts night in du Recent ahead parties man You grapes pre Whi 5,vers pasta recommend category surearea20.£20 was was themedsportended Fr Near deals oriented 3 | racke fruitmeuiteendedcauselocated calmstroact standards fsta gener buck refreshment pastarange othernds favour20. relativeCrown rangedtmosphereh averageeststance includ20, 4 | AAromi gueststaendedrspaghetti starts is Ra The chain of gener buck Kid - wended pasta 20-25. varieCrown visith noiseuntuntRa guest 5 | A attract baruiteZizzi Idealpaghetti starts is Raick waterman of uperfect pastag guest guest pasta 20-25. guestlose guest costs tRah 6 | sensibl attractmeuite are towunfriendly Two is a which waterman of genererfectg pasta Specializlose guest affordabl 7 | Ac spenuite Ha tow here Two isimateick suit of gener this Specializ guest was 8 | A restaurant that near When expectwenty Two is aick suit of, achiev 9 | A restaurant that near The customerswenty Two is aick suit of, fruit 10 | A restaurant that near The Eaglewenty Two is a good out of, Italian -------------------------------------------------------------------------------- /examples/data/representative/out/embedding-diffusion/51 (5.72e-02).txt: -------------------------------------------------------------------------------- 1 | guest friend bad whetrban whatwithmppl evening simarea guestugh liked community under scentre biteuick Is spen Hall Priceyelative pasta regular Near w Collegespecial 2 | College average somewh Fitzbillies class Gr toward river trip Recent outlet tag Of Grovelocatare In cho shrimp ca Adults....... tapa sure type space don spenfact couples block gener 3 | as what aly Familyecpounds inexpensive service beluga Japanese College pasta table guestare Jo choended feedback Adults Grub guests Sushiunch feed chargary couples pasta blocks 4 | busurni a Has entire joint Japanese guestspecialleui20. WelcomAromi part ambient Lowree cater Adultsab slight medi entire closecosts themed Ha guest herbs blocks 5 | 25£ended a winery spen high Lo Le coffee mediocre ratinguihouse affordablelocat choices located NearE feedbacke Boat outunch ratingience guest mediocre guest 6 | reason is a river spen high Moderately tradition coffee mediocreoffee towhouse Welcom city choices located Near The25£e Boat out based heart In rang 7 | lo is aly spen high 5 traditionended uni called tow at affordable city £30. located near The Rice Boat. rang 8 | There is a proud rated high 5 star coffee shop called Wildwood at affordable city centre located near The Rice Boat. 9 | There is aly rated high 5 star coffee shop called Wildwood at affordable city centre located near The Rice Boat. 10 | There is aly rated high 5 star coffee shop called Wildwood at affordable city centre located near The Rice Boat. 
-------------------------------------------------------------------------------- /examples/data/representative/out/embedding-diffusion/960 (1.87e-02).txt: -------------------------------------------------------------------------------- 1 | nown which su fam hear esca Cheal consumers25£ entire bathroomuite speak bite riverfrontp-£25 meacentre Thmin riverocick ra one did offer don£20 bus 2 | and Ra 3.33uite roadde inexpensive eatub any Som Vegeta, pasta tapa moment Hall appetizers al sample work Typically-25.ction trip starts based mile which anCrown 3 | responseOrientedshop su roaduipl Avalon ra dr cheaperrant Ideal, Ideallygs baby Chines activit al modern50. 20-25. pasta freshic starts Specializ Park guest blocks river 4 | spen consistent anwit coupleui ob Look food delivers cheaper lack Hotel, centrked Watermanated includ pasta Sospir pasta freshurther pasta Name drinks happ noodles 5 | and costumers an Andoffeeta chain Look foodother Crown Plazasit,and The Waterman Hall includertain facility pasta pasta guestted Chines river 6 | standardsOrientedick 30£ friendly venue On Look food thing Crown Plaza Hotel, called The Waterman. includ pasta heart50. noodles 7 | eat is man child friendly venue serving Look food Near Crown Plaza Hotel, called The Waterman. pu pasta guest average noodles 8 | There is a child friendly venue serving French food Near Crown Plaza worse, called The Waterman. noodles 9 | There is a child friendly venue serving French food near Crown Plaza Hotel, called The Waterman. 10 | There is a child friendly venue serving French food near Crown Plaza Hotel, called The Waterman. -------------------------------------------------------------------------------- /examples/data/representative/out/transformer-diffusion/26-2.42e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/transformer-diffusion/26-2.42e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/transformer-diffusion/3650-7.06e-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/transformer-diffusion/3650-7.06e-03.png -------------------------------------------------------------------------------- /examples/data/representative/out/transformer-diffusion/4986-7.32e-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/transformer-diffusion/4986-7.32e-03.png -------------------------------------------------------------------------------- /examples/data/representative/out/transformer-diffusion/5002-6.01e-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/transformer-diffusion/5002-6.01e-03.png -------------------------------------------------------------------------------- /examples/data/representative/out/transformer-diffusion/509-1.29e-01.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/transformer-diffusion/509-1.29e-01.png -------------------------------------------------------------------------------- /examples/data/representative/out/unconditional-diffusion/18-3.67e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/unconditional-diffusion/18-3.67e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/unconditional-diffusion/2-5.50e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/unconditional-diffusion/2-5.50e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/unconditional-diffusion/36-3.63e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/unconditional-diffusion/36-3.63e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/unconditional-diffusion/42-3.40e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/unconditional-diffusion/42-3.40e-02.png -------------------------------------------------------------------------------- /examples/data/representative/out/unconditional-diffusion/51-3.40e-02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cabralpinto/modular-diffusion/4d919974fcf8ec5108f84122ce18e9a9ba46fd35/examples/data/representative/out/unconditional-diffusion/51-3.40e-02.png -------------------------------------------------------------------------------- /examples/embedding-diffusion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | sys.path.append(".") 7 | sys.path.append("examples") 8 | 9 | from utils import download, tokenize 10 | 11 | import diffusion 12 | from diffusion.data import Embedding 13 | from diffusion.loss import Simple 14 | from diffusion.net import Transformer 15 | from diffusion.noise import Gaussian 16 | from diffusion.schedule import Sqrt 17 | 18 | file = Path(__file__) 19 | input = file.parent / "data/in/e2e" 20 | output = file.parent / "data/out" / file.stem 21 | output.mkdir(parents=True, exist_ok=True) 22 | torch.set_grad_enabled(False) 23 | torch.set_float32_matmul_precision("high") 24 | 25 | if not input.exists(): 26 | url = "https://raw.githubusercontent.com/tuetschek/e2e-dataset/master/trainset.csv" 27 | input.mkdir(parents=True) 28 | download(url, input / "text") 29 | text = (input / "text").read_text().replace("\n ", "").split("\n")[1:-1] 30 | text = "\n".join(line.rsplit('",', 1)[1][1:-1] for line in text) 31 | (input / "text").write_text(text) 32 | tokenize(input / "text", input / "ids", size=2048, pad=True) 33 | 
(input / "text").unlink() 34 | 35 | v = (input / "vocabulary").read_text().split()[::2] 36 | x = (input / "ids").read_text().split("\n") 37 | x = [[int(w) for w in l] for s in x if len(l := s.split()) <= 128] 38 | x = torch.tensor([s + [1] * (128 - len(s)) for s in x]) 39 | 40 | model = diffusion.Model( 41 | data=Embedding(x, k=len(v), d=32, batch=64, shuffle=True), 42 | schedule=Sqrt(2000), 43 | noise=Gaussian(parameter="x", variance="fixed"), 44 | loss=Simple(parameter="x"), 45 | net=Transformer(input=32, width=1024, depth=16, heads=16), 46 | device="cuda" if torch.cuda.is_available() else "cpu", 47 | ) 48 | 49 | if (output / "model.pt").exists(): 50 | model.load(output / "model.pt") 51 | epoch = sum(1 for _ in output.glob("[0-9]*")) 52 | 53 | for epoch, loss in enumerate(model.train(epochs=10000), epoch + 1): 54 | *_, z = model.sample(batch=10) 55 | z = ["".join(v[w] for w in s).replace("▁", " ").lstrip() for s in z.int()] 56 | (output / f"{epoch}-{loss:.2e}.txt").write_text("\n".join(z)) 57 | model.save(output / "model.pt") -------------------------------------------------------------------------------- /examples/text-diffusion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import torch 5 | from torch.nn import Softmax 6 | 7 | sys.path.append(".") 8 | sys.path.append("examples") 9 | 10 | from utils import download, tokenize 11 | 12 | import diffusion 13 | from diffusion.data import OneHot 14 | from diffusion.distribution import Categorical as Cat 15 | from diffusion.loss import VLB, Lambda 16 | from diffusion.net import Transformer 17 | from diffusion.noise import Absorbing 18 | from diffusion.schedule import Sqrt 19 | 20 | file = Path(__file__) 21 | input = file.parent / "data/in/e2e" 22 | output = file.parent / "data/out" / file.stem 23 | output.mkdir(parents=True, exist_ok=True) 24 | torch.set_grad_enabled(False) 25 | torch.set_float32_matmul_precision("high") 26 | 27 | if not input.exists(): 28 | url = "https://raw.githubusercontent.com/tuetschek/e2e-dataset/master/trainset.csv" 29 | input.mkdir(parents=True) 30 | download(url, input / "text") 31 | text = (input / "text").read_text().replace("\n ", "").split("\n")[1:-1] 32 | text = "\n".join(line.rsplit('",', 1)[1][1:-1] for line in text) 33 | (input / "text").write_text(text) 34 | tokenize(input / "text", input / "ids", size=2048, pad=True) 35 | (input / "text").unlink() 36 | 37 | v = (input / "vocabulary").read_text().split()[::2] + ["?"] 38 | x = (input / "ids").read_text().split("\n") 39 | x = [[int(w) for w in l] for s in x if len(l := s.split()) <= 128] 40 | x = torch.tensor([s + [1] * (128 - len(s)) for s in x]) 41 | 42 | model = diffusion.Model( 43 | data=OneHot(x, k=len(v), batch=32, shuffle=True), 44 | schedule=Sqrt(1000), 45 | noise=Absorbing(len(v)), 46 | loss=VLB() + 1e-2 * Lambda[Cat](lambda batch: Cat(batch.hat[0]).nll(batch.x).sum()), 47 | net=Transformer(input=len(v), width=1024, depth=16, heads=16) | Softmax(3), 48 | device="cuda" if torch.cuda.is_available() else "cpu", 49 | ) 50 | 51 | if (output / "model.pt").exists(): 52 | model.load(output / "model.pt") 53 | epoch = sum(1 for _ in output.glob("[0-9]*")) 54 | 55 | for epoch, loss in enumerate(model.train(epochs=10000), epoch + 1): 56 | z = model.sample() 57 | z = z[torch.linspace(0, z.shape[0] - 1, 10).int(), 0].int() 58 | z = ["".join(v[w] for w in s).replace("▁", " ").lstrip() for s in z] 59 | (output / f"{epoch}-{loss:.2e}.txt").write_text("\n".join(z)) 60 | 
model.save(output / "model.pt") 61 | -------------------------------------------------------------------------------- /examples/transformer-diffusion.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import sys 3 | from pathlib import Path 4 | 5 | import torch 6 | from einops import rearrange 7 | from torchvision.datasets import ImageFolder 8 | from torchvision.transforms import ToTensor 9 | from torchvision.transforms.functional import resize 10 | from torchvision.utils import save_image 11 | 12 | sys.path.append(".") 13 | sys.path.append("examples") 14 | 15 | from utils import download 16 | 17 | import diffusion 18 | from diffusion.data import Identity 19 | from diffusion.loss import Simple 20 | from diffusion.net import Transformer 21 | from diffusion.noise import Gaussian 22 | from diffusion.schedule import Cosine 23 | 24 | file = Path(__file__) 25 | input = file.parent / "data/in/afhq" 26 | input.parent.mkdir(parents=True, exist_ok=True) 27 | output = file.parent / "data/out" / file.stem 28 | output.mkdir(parents=True, exist_ok=True) 29 | torch.set_float32_matmul_precision("high") 30 | torch.set_grad_enabled(False) 31 | 32 | if not input.exists(): 33 | download("https://www.dropbox.com/s/t9l9o3vsx2jai3z/afhq.zip?dl=1", "afhq.zip") 34 | shutil.unpack_archive("afhq.zip", input.parent) 35 | Path("afhq.zip").unlink() 36 | (input / "dog").mkdir() 37 | for path in input.glob("*/dog/*"): 38 | path.rename(input / "dog" / path.name) 39 | for path in input / "train", input / "val": 40 | shutil.rmtree(path) 41 | 42 | c, h, w, p, q = 3, 64, 64, 2, 2 43 | x, _ = zip(*ImageFolder(str(input), ToTensor())) 44 | x = torch.stack(x) * 2 - 1 45 | x = resize(x, [h, w], antialias=False) 46 | x = rearrange(x, "b c (h p) (w q) -> b (h w) (c p q)", p=p, q=q) 47 | 48 | model = diffusion.Model( 49 | data=Identity(x, batch=16, shuffle=True), 50 | schedule=Cosine(1000), 51 | noise=Gaussian(parameter="epsilon", variance="fixed"), 52 | net=Transformer(input=x.shape[2], width=768, depth=12, heads=12), 53 | loss=Simple(parameter="epsilon"), 54 | device="cuda" if torch.cuda.is_available() else "cpu", 55 | ) 56 | 57 | if (output / "model.pt").exists(): 58 | model.load(output / "model.pt") 59 | epoch = sum(1 for _ in output.glob("[0-9]*")) 60 | 61 | for epoch, loss in enumerate(model.train(epochs=10000), epoch + 1): 62 | z = model.sample(batch=10) 63 | print(z[-1].min().item(), z[-1].max().item(), flush=True) 64 | z = z[torch.linspace(0, z.shape[0] - 1, 10).int()] 65 | z = rearrange(z, "t b (h w) (c p q) -> c (b h p) (t w q)", h=h // p, p=p, q=q) 66 | z = (z + 1) / 2 67 | save_image(z, output / f"{epoch}-{loss:.2e}.png") 68 | model.save(output / "model.pt") 69 | -------------------------------------------------------------------------------- /examples/unconditional-diffusion.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import torch 5 | from einops import rearrange 6 | from torchvision.datasets import MNIST 7 | from torchvision.transforms import ToTensor 8 | from torchvision.utils import save_image 9 | 10 | sys.path.append(".") 11 | 12 | import diffusion 13 | from diffusion.data import Identity 14 | from diffusion.loss import Simple 15 | from diffusion.net import UNet 16 | from diffusion.noise import Gaussian 17 | from diffusion.schedule import Linear 18 | 19 | file = Path(__file__) 20 | input = file.parent / "data/in" 21 | output = file.parent / "data/out" / file.stem 22 | 
output.mkdir(parents=True, exist_ok=True) 23 | torch.set_float32_matmul_precision("high") 24 | torch.set_grad_enabled(False) 25 | 26 | x, _ = zip(*MNIST(str(input), transform=ToTensor(), download=True)) 27 | x = torch.stack(x) * 2 - 1 28 | 29 | model = diffusion.Model( 30 | data=Identity(x, batch=128, shuffle=True), 31 | schedule=Linear(1000, 0.9999, 0.98), 32 | noise=Gaussian(parameter="epsilon", variance="fixed"), 33 | net=UNet(channels=(1, 64, 128, 256)), 34 | loss=Simple(parameter="epsilon"), 35 | device="cuda" if torch.cuda.is_available() else "cpu", 36 | ) 37 | 38 | if (output / "model.pt").exists(): 39 | model.load(output / "model.pt") 40 | epoch = sum(1 for _ in output.glob("[0-9]*")) 41 | 42 | for epoch, loss in enumerate(model.train(epochs=1000), 1): 43 | z = model.sample(batch=10) 44 | z = z[torch.linspace(0, z.shape[0] - 1, 10).int()] 45 | z = rearrange(z, "t b c h w -> c (b h) (t w)") 46 | z = (z + 1) / 2 47 | save_image(z, output / f"{epoch}-{loss:.2e}.png") 48 | model.save(output / "model.pt") 49 | -------------------------------------------------------------------------------- /examples/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | import requests 5 | from sentencepiece import SentencePieceProcessor, SentencePieceTrainer 6 | from tqdm import tqdm 7 | 8 | 9 | def download(url: str, path: Path | str) -> None: 10 | response = requests.get(url, stream=True) 11 | bar = tqdm( 12 | unit='B', 13 | unit_scale=True, 14 | unit_divisor=1024, 15 | miniters=1, 16 | total=int(response.headers.get('content-length', 0)), 17 | ) 18 | with open(path, "wb") as file: 19 | for chunk in response.iter_content(chunk_size=4096): 20 | file.write(chunk) 21 | bar.update(len(chunk)) 22 | bar.refresh() 23 | bar.close() 24 | 25 | 26 | def tokenize( 27 | input: Path | str, 28 | output: Path | str, 29 | size: int, 30 | pad: bool = False, 31 | ) -> None: 32 | input, output = Path(input), Path(output) 33 | SentencePieceTrainer.Train( 34 | input=input, 35 | model_prefix=output.parent / "_", 36 | vocab_size=size, 37 | normalization_rule_name='nfkc_cf', 38 | pad_id=1 if pad else -1, 39 | bos_id=-1, 40 | eos_id=-1, 41 | split_digits=True, 42 | input_sentence_size=1000000, 43 | shuffle_input_sentence=True, 44 | ) 45 | shutil.move(output.parent / "_.vocab", output.parent / "vocabulary") 46 | text = input.read_text().split("\n") 47 | sp = SentencePieceProcessor(str(output.parent / "_.model")) # type: ignore 48 | ids = "\n".join(" ".join(map(str, line)) for line in sp.Encode(text)) 49 | output.write_text(ids) 50 | (output.parent / "_.model").unlink() -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "modular-diffusion" 7 | version = "0.0.3" 8 | authors = [{ name = "João Cabral Pinto", email = "jmcabralpinto@gmail.com" }] 9 | description = "Modular Diffusion" 10 | readme = "README.md" 11 | requires-python = ">=3.10" 12 | classifiers = [ 13 | "Programming Language :: Python :: 3", 14 | "License :: OSI Approved :: MIT License", 15 | "Operating System :: OS Independent", 16 | ] 17 | dependencies = [ 18 | "einops==0.6.1", 19 | "tqdm==4.64.1", 20 | "typing_extensions==4.7.1", 21 | ] 22 | 23 | [project.urls] 24 | "Homepage" = "https://github.com/cabralpinto/modular-diffusion" 
25 | "Bug Tracker" = "https://github.com/cabralpinto/modular-diffusion/issues" 26 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "typeCheckingMode": "strict", 3 | "reportUnknownMemberType": "none", 4 | "reportUnknownArgumentType": "none", 5 | "reportUnknownVariableType": "none", 6 | "reportMissingTypeStubs": "none", 7 | "reportUnusedImport": "warning", 8 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | tqdm==4.64.1 3 | typing_extensions==4.7.1 4 | --------------------------------------------------------------------------------