├── VERSION ├── data ├── raw │ └── .gitkeep └── README.md ├── src └── generativezoo │ ├── data │ ├── __init__.py │ └── CycleGAN_Dataloaders.py │ ├── models │ ├── __init__.py │ └── SM │ │ └── normalization.py │ ├── utils │ └── __init__.py │ ├── __init__.py │ ├── version.py │ ├── Text2Img_LoRA.py │ ├── MaskGit.py │ ├── VQGAN.py │ ├── SGM.py │ ├── RF.py │ ├── Text2Img_ControlNet.py │ ├── GPT.py │ ├── HVAE.py │ ├── CondDDPM.py │ ├── InstructPix2Pix.py │ ├── CondGan.py │ ├── config.py │ ├── CondFM.py │ ├── VanFlow.py │ ├── FM.py │ ├── CondVAE.py │ ├── DAE.py │ ├── P-CNN.py │ ├── FlowPP.py │ ├── RNVP.py │ ├── DDPM.py │ ├── DCGAN.py │ ├── VanVAE.py │ ├── CycGAN.py │ ├── WGAN.py │ ├── NCSNv2.py │ ├── GLOW.py │ ├── AdvVAE.py │ └── PresGAN.py ├── imgs └── logo_tasti_light.png ├── requirements ├── requirements.txt └── requirements_docker.txt ├── scripts ├── main.sh ├── job_docker.sh ├── README.md ├── job_apptainer.sh ├── shellcheck.sh ├── hadolint.sh ├── Text2Img_Lora.sh ├── Text2Img_Controlnet.sh ├── InstructPix2Pix.sh └── nbconveter.sh ├── AUTHORS.md ├── Dockerfile ├── cookiecutter-config-file.yml ├── docs ├── help.md ├── MODELRULES.md ├── Text2Img_LoRA.md ├── LOADING_ENV_VARIABLES.md ├── Text2Img_ControlNet.md ├── InstructPix2Pix.md ├── HierarchicalVAE.md ├── CycleGAN.md ├── SECURITY.md ├── PixelCNN.md ├── ConditionalGAN.md ├── ConditionalVAE.md ├── DiffusionAE.md ├── RealNVP.md ├── WassersteinGAN.md ├── VanillaFlow.md ├── VanillaVAE.md ├── Glow.md ├── FlowPlusPlus.md ├── RectifiedFlows.md ├── NCSNv2.md ├── DCGAN.md ├── CONTRIBUTING.md ├── AdversarialVAE.md ├── FlowMatching.md └── PrescribedGAN.md ├── setup.py ├── .dockerignore ├── pyproject.toml ├── .editorconfig ├── .flake8 ├── .envrc └── .env /VERSION: -------------------------------------------------------------------------------- 1 | 1.3.0 2 | -------------------------------------------------------------------------------- /data/raw/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/generativezoo/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/generativezoo/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/generativezoo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/logo_tasti_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caetas/GenerativeZoo/HEAD/imgs/logo_tasti_light.png -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caetas/GenerativeZoo/HEAD/requirements/requirements.txt -------------------------------------------------------------------------------- /requirements/requirements_docker.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caetas/GenerativeZoo/HEAD/requirements/requirements_docker.txt 
-------------------------------------------------------------------------------- /scripts/main.sh: -------------------------------------------------------------------------------- 1 | cd src/generativezoo 2 | python VanVAE.py \ 3 | --train \ 4 | --hidden_dims 16 32 64 \ 5 | --n_epochs 10 6 | -------------------------------------------------------------------------------- /scripts/job_docker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | docker run --rm --gpus all --ipc=host \ 4 | --env-file ../.env \ 5 | -v $(pwd)/../:/app/ \ 6 | generativezoo /bin/bash scripts/main.sh -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Scripts 2 | 3 | This folder holds command line interface (CLI) scripts. These scripts typically provide entry points to kick off common 4 | tasks in your data science project, such as model training or inference. 5 | -------------------------------------------------------------------------------- /src/generativezoo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for generativezoo. 4 | 5 | A Model Zoo for Generative Models. 6 | """ 7 | from version import __version__ 8 | 9 | __author__ = "Francisco Caetano" 10 | -------------------------------------------------------------------------------- /scripts/job_apptainer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../ 4 | # Execute the command inside the Apptainer container 5 | apptainer exec \ 6 | --nv \ 7 | --env-file .env \ 8 | --bind $(pwd)/:/app/ \ 9 | --pwd /app \ 10 | generativezoo.sif /bin/bash scripts/main.sh -------------------------------------------------------------------------------- /src/generativezoo/version.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | from config import project_dir 4 | from loguru import logger 5 | 6 | with open(join(project_dir, "VERSION"), encoding="utf-8") as f: 7 | __version__ = f.read().strip() 8 | 9 | 10 | if __name__ == "__main__": 11 | logger.info(__version__) 12 | 13 | # EOF 14 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | Credits 2 | ======= 3 | 4 | Project Lead 5 | ---------------- 6 | 7 | * Francisco Caetano 8 | 9 | Project Contributors 10 | ------------ 11 | 12 | ```shell 13 | git shortlog -nse 14 | ``` 15 | 16 | Have an idea that you want to contribute to this project? Check out our [CONTRIBUTING](docs/CONTRIBUTING.md) guide!
17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.12-py3 2 | 3 | RUN apt-get update && apt-get install -y git 4 | 5 | COPY requirements/requirements_docker.txt /app/requirements/ 6 | RUN pip install -r /app/requirements/requirements_docker.txt 7 | RUN mkdir /app/data 8 | RUN mkdir /app/src 9 | RUN mkdir /app/models 10 | 11 | WORKDIR /app/ 12 | 13 | ENV PYTHONPATH="${PYTHONPATH}:/app/src/generativezoo" -------------------------------------------------------------------------------- /cookiecutter-config-file.yml: -------------------------------------------------------------------------------- 1 | # This file contains values from Cookiecutter 2 | 3 | default_context: 4 | project_name: "GenerativeZoo" 5 | repo_name: "GenerativeZoo" 6 | description: "Model Zoo for Generative Models" 7 | organization: "TU/e" 8 | license: "CC-BY-4.0" 9 | organization_email: "f.t.de.espirito.santo.e.caetano@tue.nl" 10 | line_length: "120" 11 | python_minimal_version: "3.9" 12 | -------------------------------------------------------------------------------- /scripts/shellcheck.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Ensure shell scripts conform to shellcheck. 4 | # 5 | 6 | set -eu 7 | 8 | readonly DEBUG=${DEBUG:-unset} 9 | if [ "${DEBUG}" != unset ]; then 10 | set -x 11 | fi 12 | 13 | if ! command -v shellcheck >/dev/null 2>&1; then 14 | echo 'shellcheck not installed; Please install shellcheck:' 15 | echo ' sudo apt install -y shellcheck' 16 | exit 1 17 | fi 18 | 19 | shellcheck "$@" 20 | -------------------------------------------------------------------------------- /docs/help.md: -------------------------------------------------------------------------------- 1 | # Welcome to MkDocs 2 | 3 | For full documentation visit [mkdocs.org](https://www.mkdocs.org). 4 | 5 | ## Commands 6 | 7 | * `mkdocs new [dir-name]` - Create a new project. 8 | * `mkdocs serve` - Start the live-reloading docs server. 9 | * `mkdocs build` - Build the documentation site. 10 | * `mkdocs -h` - Print help message and exit. 11 | 12 | ## Project layout 13 | 14 | mkdocs.yml # The configuration file. 15 | docs/ 16 | index.md # The documentation homepage. 17 | ... # Other markdown pages, images and other files. 18 | -------------------------------------------------------------------------------- /scripts/hadolint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Haskell Dockerfile Linter 4 | # Dockerfile linter that helps you build best practice Docker images 5 | # https://docs.docker.com/engine/userguide/eng-image/dockerfile_best-practices 6 | # 7 | 8 | set -eu 9 | 10 | readonly DEBUG=${DEBUG:-unset} 11 | if [ "${DEBUG}" != unset ]; then 12 | set -x 13 | fi 14 | 15 | if ! 
command -v hadolint >/dev/null 2>&1; then 16 | echo 'hadolint not installed; Please install hadolint:' 17 | echo ' download the binary from https://github.com/hadolint/hadolint/releases' 18 | exit 1 19 | fi 20 | 21 | hadolint "$@" 22 | -------------------------------------------------------------------------------- /src/generativezoo/Text2Img_LoRA.py: -------------------------------------------------------------------------------- 1 | from diffusers import AutoPipelineForText2Image 2 | import torch 3 | import argparse 4 | from matplotlib import pyplot as plt 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description="Text2Img LoRA") 8 | parser.add_argument("--lora_model_path", type=str, default="./../../models/Text2Img_Lora/naruto/checkpoint-60", help="Path to LoRA model") 9 | args = parser.parse_args() 10 | return args 11 | 12 | args = parse_args() 13 | pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16).to("cuda") 14 | pipeline.load_lora_weights(args.lora_model_path, weight_name="pytorch_lora_weights.safetensors") 15 | 16 | while True: 17 | text = input("Enter prompt (0 to exit): ") 18 | if text == "0": 19 | break 20 | image = pipeline(text).images[0] 21 | plt.imshow(image) 22 | plt.show() -------------------------------------------------------------------------------- /scripts/Text2Img_Lora.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stabilityai/stable-diffusion-2-1" 2 | export OUTPUT_DIR="./../../../../models/Text2Img_Lora/pokemons" 3 | export DATASET_PATH="./../../../../data/processed/pokemons" 4 | 5 | cd ./../src/generativezoo/models/SD 6 | 7 | accelerate launch --mixed_precision="fp16" Text2Img_Lora.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --train_data_dir=$DATASET_PATH \ 10 | --dataloader_num_workers=8 \ 11 | --resolution=512 \ 12 | --center_crop \ 13 | --random_flip \ 14 | --train_batch_size=2 \ 15 | --gradient_accumulation_steps=1 \ 16 | --max_train_steps=15000 \ 17 | --learning_rate=1e-04 \ 18 | --max_grad_norm=1 \ 19 | --lr_scheduler="cosine" \ 20 | --lr_warmup_steps=0 \ 21 | --output_dir=${OUTPUT_DIR} \ 22 | --report_to=wandb \ 23 | --checkpointing_steps=500 \ 24 | --validation_prompt="A pokemon with blue wings." 
\ 25 | --seed=1337 \ 26 | --image_column="image" \ 27 | --caption_column="text" -------------------------------------------------------------------------------- /scripts/Text2Img_Controlnet.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stabilityai/stable-diffusion-2-1" 2 | export OUTPUT_DIR="./../../../../models/Text2Img_Controlnet/pokemons" 3 | export DATASET_PATH="./../../../../data/processed/pokemons" 4 | 5 | cd ./../src/generativezoo/models/SD 6 | 7 | accelerate launch --mixed_precision="fp16" Text2Img_Controlnet.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --output_dir=$OUTPUT_DIR \ 10 | --train_data_dir=$DATASET_PATH \ 11 | --resolution=512 \ 12 | --learning_rate=1e-5 \ 13 | --max_train_steps=10000 \ 14 | --checkpointing_steps=20 \ 15 | --validation_image "./../../../../data/processed/pokemons/conditioning_images/0003_mask.png" \ 16 | --validation_prompt "red circle pokemon with white dots" \ 17 | --train_batch_size=1 \ 18 | --report_to=wandb \ 19 | --image_column="image" \ 20 | --caption_column="text" \ 21 | --conditioning_image_column='conditioning_image' \ 22 | --gradient_accumulation_steps=1 \ 23 | --gradient_checkpointing \ 24 | --use_8bit_adam -------------------------------------------------------------------------------- /src/generativezoo/MaskGit.py: -------------------------------------------------------------------------------- 1 | from models.AR.MaskGiT import MaskGIT 2 | from utils.util import parse_args_MaskGiT 3 | from data.Dataloaders import * 4 | 5 | if __name__ == "__main__": 6 | args = parse_args_MaskGiT() 7 | 8 | if args.train: 9 | train_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 10 | val_loader, _, _ = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 11 | model = MaskGIT(args, channels, input_size) 12 | model.train_model(train_loader, val_loader) 13 | 14 | elif args.sample: 15 | _, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 16 | model = MaskGIT(args, channels, input_size) 17 | model.load_checkpoint(args.checkpoint_vit) 18 | model.sample() -------------------------------------------------------------------------------- /src/generativezoo/VQGAN.py: -------------------------------------------------------------------------------- 1 | from models.GAN.VQGAN import VQModel 2 | from utils.util import parse_args_VQGAN 3 | from data.Dataloaders import * 4 | 5 | if __name__ == "__main__": 6 | args = parse_args_VQGAN() 7 | 8 | if args.train: 9 | train_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 10 | val_loader, _, _ = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 11 | model = VQModel(args, channels, input_size) 12 | model.train_model(train_loader, val_loader) 13 | 14 | elif args.reconstruct: 15 | val_loader, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 16 | model = VQModel(args, channels, input_size) 17 | model.load_checkpoint(args.checkpoint) 18 | model.reconstruct(val_loader)
-------------------------------------------------------------------------------- /src/generativezoo/SGM.py: -------------------------------------------------------------------------------- 1 | from models.SM.SGM import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_SGM 4 | import torch 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | args = parse_args_SGM() 11 | normalize = True 12 | 13 | 14 | if args.train: 15 | dataloader, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=normalize, size=args.size, num_workers=args.num_workers) 16 | model = SGM(args, channels, input_size) 17 | model.train_model(dataloader) 18 | 19 | elif args.sample: 20 | _, input_size, channels = pick_dataset(args.dataset, 'val', args.batch_size, normalize=normalize, size=args.size) 21 | model = SGM(args, channels, input_size) 22 | model.model.load_state_dict(torch.load(args.checkpoint)) 23 | model.sample(args.num_samples) 24 | 25 | else: 26 | raise ValueError('Please specify at least one of the following: train, sample') 27 | 28 | -------------------------------------------------------------------------------- /src/generativezoo/RF.py: -------------------------------------------------------------------------------- 1 | from data.Dataloaders import pick_dataset 2 | from models.FM.RectifiedFlows import RF 3 | from utils.util import parse_args_RectifiedFlows 4 | 5 | if __name__ == '__main__': 6 | 7 | args = parse_args_RectifiedFlows() 8 | 9 | 10 | if args.train: 11 | train_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 12 | model = RF(args, input_size, channels) 13 | model.train_model(train_loader) 14 | 15 | elif args.sample: 16 | _, input_size, channels = pick_dataset(args.dataset, batch_size = 1, normalize=True, size=args.size) 17 | model = RF(args, input_size, channels) 18 | model.load_checkpoint(args.checkpoint) 19 | model.sample(16) 20 | 21 | elif args.fid: 22 | _, input_size, channels = pick_dataset(args.dataset, batch_size = 1, normalize=True, size=args.size) 23 | model = RF(args, input_size, channels) 24 | model.load_checkpoint(args.checkpoint) 25 | model.fid_sample() -------------------------------------------------------------------------------- /scripts/InstructPix2Pix.sh: -------------------------------------------------------------------------------- 1 | export MODEL_NAME="stabilityai/stable-diffusion-2-1" 2 | export OUTPUT_DIR="./../../../../models/InstructPix2Pix/pokemons" 3 | export DATASET_PATH="./../../../../data/processed/pokemons" 4 | 5 | cd ./../src/generativezoo/models/SD 6 | 7 | accelerate launch --mixed_precision="fp16" InstructPix2Pix.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --train_data_dir=$DATASET_PATH \ 10 | --output_dir=$OUTPUT_DIR \ 11 | --resolution=512 --random_flip \ 12 | --train_batch_size=4 --gradient_accumulation_steps=1 --gradient_checkpointing \ 13 | --max_train_steps=15000 \ 14 | --checkpointing_steps=5000 \ 15 | --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ 16 | --conditioning_dropout_prob=0.05 \ 17 | --validation_image "./../../../../data/processed/pokemons/conditioning_images/0003_mask.png" \ 18 | --validation_prompt "red circle pokemon with white dots" \ 19 | --seed=42 \ 20 | --report_to=wandb \ 21 | --original_image_column="original_image" \ 22 | --edited_image_column="edited_image" \ 23 | 
--edit_prompt_column="edit_prompt" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | ROOT = Path(__file__).parent 6 | 7 | with open("README.md") as fh: 8 | long_description = fh.read() 9 | 10 | 11 | def find_requirements(filename): 12 | with (ROOT / "requirements" / filename).open() as f: 13 | return [s for s in [line.strip(" \n") for line in f] if not s.startswith("#") and s != ""] 14 | 15 | 16 | runtime_requires = find_requirements("requirements.txt") 17 | #dev_requires = find_requirements("requirements-dev.txt") 18 | #docs_require = find_requirements("requirements-docs.txt") 19 | 20 | 21 | setup( 22 | name="GenerativeZoo", 23 | version="1.2.0", 24 | author="Francisco Caetano", 25 | description="A Model Zoo for Generative Models.", 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | packages=find_packages(where="src"), 29 | package_dir={"": "src"}, 30 | python_requires=">=3.9.0", 31 | install_requires=runtime_requires, 32 | #extras_require={ 33 | # "dev": dev_requires, 34 | # "docs": docs_require, 35 | #}, 36 | ) 37 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Ignore generated files 2 | **/*.pyc 3 | 4 | # JetBrains Pycharm 5 | .idea/ 6 | 7 | # VS Code 8 | .vscode/* 9 | !.vscode/settings.json 10 | !.vscode/tasks.json 11 | !.vscode/launch.json 12 | !.vscode/extensions.json 13 | *.code-workspace 14 | **/.vscode 15 | # Local History for Visual Studio Code 16 | .history/ 17 | 18 | # Logs 19 | logs/ 20 | 21 | #Ignoring all the markdown and class files 22 | *.md 23 | 24 | # Sphinx documentation 25 | docs/_build/ 26 | 27 | # mkdocs documentation 28 | site/ 29 | 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | **/__pycache__/ 33 | *.pyc 34 | *.pyo 35 | *.pyd 36 | .Python 37 | *.py[cod] 38 | *$py.class 39 | .pytest_cache/ 40 | 41 | # C extensions 42 | *.so 43 | 44 | # Mypy 45 | .mypy_cache/ 46 | .dmypy.json 47 | dmypy.json 48 | 49 | # Virtual environment 50 | venv/ 51 | .venv/ 52 | .venv-docs/ 53 | .venv-dev/ 54 | .venv-note/ 55 | .venv-dempy/ 56 | 57 | # Mac 58 | .DS_Store 59 | .AppleDouble 60 | .LSOverride 61 | ._* 62 | 63 | # Git 64 | .git/ 65 | .gitignore 66 | 67 | # Dempy Cache 68 | .cache/ 69 | 70 | # Hydra 71 | outputs/ 72 | runs/ 73 | 74 | # Cruft 75 | .cruft.json 76 | 77 | # EOF 78 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | This folder holds datasets that are versioned. Typically, these are small, dummy datasets for dev/test from your favourite 4 | CLI, IDE, or debugger. 5 | 6 | Any data that needs to be stored locally should be saved in this location. 7 | The sub-folders should be used as follows: 8 | 9 | - `external`: any data that will not be processed at all, such as reference data; 10 | - `raw`: any raw data before any processing; 11 | - `interim`: any raw data that has been partially processed and, for whatever reason, 12 | needs to be stored before further processing is completed; and 13 | - `processed`: any raw or interim data that has been fully processed into its final 14 | state. 
15 | 16 | The paths for these directories are loaded as environment variables by the 17 | `.envrc` file. To load them in Python, use any or all of the following code: 18 | 19 | ```python 20 | import os 21 | 22 | # Load environment variables for the `data` folder, and its sub-folders 23 | DIR_DATA = os.getenv("DIR_DATA") 24 | DIR_DATA_EXTERNAL = os.getenv("DIR_DATA_EXTERNAL") 25 | DIR_DATA_RAW = os.getenv("DIR_DATA_RAW") 26 | DIR_DATA_INTERIM = os.getenv("DIR_DATA_INTERIM") 27 | DIR_DATA_PROCESSED = os.getenv("DIR_DATA_PROCESSED") 28 | ``` 29 | -------------------------------------------------------------------------------- /scripts/nbconveter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Automatically exclude notebook outputs. 4 | # https://medium.com/somosfit/version-control-on-jupyter-notebooks-6b67a0cf12a3 5 | # https://pypi.org/project/nbstripout 6 | # https://github.com/jupyter/nbdime 7 | # https://github.com/mwouts/jupytext 8 | # https://www.fotonixx.com/posts/data-science-vcs/ 9 | # https://departmentfortransport.github.io/ds-processes/Coding_standards/ipython.html 10 | # https://zhauniarovich.com/post/2020/2020-06-clearing-jupyter-output/ 11 | 12 | if command -v jupyter >/dev/null 2>&1; then 13 | echo "Clean Outputs Cells and converting notebooks to scripts ..." 14 | # https://gist.github.com/tylerneylon/697065ca5906c185ec6dd3093b237164 15 | # Convert all new Jupyter notebooks to straight Python files for easier code reviews. 16 | for file in $(git diff --cached --name-only); do 17 | if [[ $file == *.ipynb ]]; then 18 | echo -e "Converting ${file} ..." 19 | jupyter nbconvert --ClearOutputPreprocessor.enabled=True --clear-output --inplace "${file}" 20 | jupyter nbconvert --to script "${file}" 21 | git add "${file%.*}".py 22 | fi 23 | done 24 | else 25 | echo "Jupyter Notebook not installed; Please install Jupyter Notebook." 
26 | exit 1 27 | fi 28 | 29 | # EOF 30 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "generativezoo" 3 | version = "0.1.0" 4 | requires-python = ">=3.11" 5 | dependencies = [ 6 | "accelerate>=1.9.0", 7 | "bitsandbytes>=0.46.1", 8 | "datasets>=4.0.0", 9 | "diffusers>=0.34.0", 10 | "dotenv>=0.9.9", 11 | "einops>=0.8.1", 12 | "flask>=3.1.1", 13 | "fsspec>=2025.3.0", 14 | "h5py>=3.14.0", 15 | "lpips>=0.1.4", 16 | "matplotlib>=3.10.3", 17 | "mdurl>=0.1.2", 18 | "medmnist>=3.0.2", 19 | "ml-collections>=1.1.0", 20 | "monai>=1.5.0", 21 | "monai-generative>=0.2.3", 22 | "ninja>=1.11.1.4", 23 | "numpy>=2.3.2", 24 | "opencv-python>=4.11.0.86", 25 | "s3fs>=2025.3.0", 26 | "scikit-image>=0.25.2", 27 | "scikit-learn>=1.7.1", 28 | "scipy>=1.16.1", 29 | "toml>=0.10.2", 30 | "torch>=2.6.0", 31 | "torchdiffeq>=0.2.5", 32 | "torchvision>=0.21.0", 33 | "tqdm>=4.67.1", 34 | "transformers>=4.54.1", 35 | "wandb>=0.21.0", 36 | "zuko>=1.4.1", 37 | ] 38 | 39 | [[tool.uv.index]] 40 | name = "pytorch-cu126" 41 | url = "https://download.pytorch.org/whl/cu126" 42 | explicit = true 43 | 44 | [tool.uv.sources] 45 | torch = [ 46 | { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, 47 | ] 48 | torchvision = [ 49 | { index = "pytorch-cu126", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, 50 | ] 51 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig helps developers define and maintain consistent coding styles between different editors and IDEs 2 | # More information at https://EditorConfig.org 3 | 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | trim_trailing_whitespace = true 11 | charset = utf-8 12 | indent_style = space 13 | max_line_length = 120 14 | 15 | # Python: PEP8 defines 4 spaces for indentation 16 | [*.{py,java,r,R,sh}] 17 | indent_size = 4 18 | 19 | # The JSON files contain newlines inconsistently 20 | [*.json] 21 | indent_size = 2 22 | insert_final_newline = false 23 | 24 | [*.{yml,yaml}] 25 | indent_size = 2 26 | 27 | [*.{md,Rmd,rst}] 28 | max_line_length = 79 29 | trim_trailing_whitespace = false 30 | indent_size = 2 31 | 32 | # Tabs matter for Makefile and .gitmodules 33 | [{makefile*,Makefile*,*.mk,*.mak,*.makefile,*.Makefile,GNUmakefile,BSDmakefile,make.bat,Makevars*,*.gitmodules}] 34 | indent_style = tab 35 | insert_final_newline = false 36 | 37 | # Placeholder files 38 | [{*.gitkeep,__init__.py,.envrc}] 39 | insert_final_newline = false 40 | 41 | [{LICENSE, VERSION*, requirements*}] 42 | insert_final_newline = false 43 | 44 | [*.c] 45 | max_line_length = 100 46 | 47 | [*.h] 48 | max_line_length = 100 49 | 50 | [.git/*] 51 | trim_trailing_whitespace = false 52 | 53 | # Jenkinsfiles and Dockerfile files only 54 | [*.{Jenkinsfile,dockerfile}] 55 | indent_style = tab 56 | indent_size = 4 57 | -------------------------------------------------------------------------------- /src/generativezoo/Text2Img_ControlNet.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionControlNetPipeline, ControlNetModel 2 | from diffusers.utils import load_image 3 | import torch 4 | import argparse 5 | 
from matplotlib import pyplot as plt 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description="Text2Img LoRA") 9 | parser.add_argument("--cnet_model_path", type=str, default="./../../models/Text2Img_ControlNet/pokemons/checkpoint-200/controlnet", help="Path to ControlNet model") 10 | parser.add_argument("--cond_image_path", type=str, default="./../../data/processed/pokemons/conditioning_images/0003_mask.png", help="Path to conditioning image") 11 | args = parser.parse_args() 12 | return args 13 | 14 | args = parse_args() 15 | 16 | controlnet = ControlNetModel.from_pretrained(args.cnet_model_path, torch_dtype=torch.float16) 17 | pipeline = StableDiffusionControlNetPipeline.from_pretrained( 18 | "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16 19 | ).to("cuda") 20 | 21 | control_image = load_image(args.cond_image_path) 22 | 23 | while True: 24 | text = input("Enter prompt (0 to exit): ") 25 | if text == "0": 26 | break 27 | image = pipeline(text, num_inference_steps=20, image=control_image).images[0] 28 | plt.imshow(image) 29 | plt.show() 30 | new_cond = input("Enter new conditioning image (0 to keep the same): ") 31 | if new_cond == "0": 32 | continue 33 | control_image = load_image(new_cond) -------------------------------------------------------------------------------- /src/generativezoo/GPT.py: -------------------------------------------------------------------------------- 1 | from models.AR.GPT import VQGAN_GPT 2 | from utils.util import parse_args_GPT 3 | from data.Dataloaders import * 4 | 5 | if __name__ == "__main__": 6 | args = parse_args_GPT() 7 | 8 | if args.train: 9 | train_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 10 | val_loader, _, _ = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 11 | model = VQGAN_GPT(args, channels, input_size) 12 | model.train_model(train_loader, val_loader) 13 | 14 | elif args.sample: 15 | _, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 16 | model = VQGAN_GPT(args, channels, input_size) 17 | model.load_checkpoint(args.checkpoint_gpt) 18 | model.sample() 19 | 20 | elif args.outlier_detection: 21 | in_loader, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 22 | out_loader, _, _ = pick_dataset(args.out_dataset, mode='val', batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=input_size) 23 | model = VQGAN_GPT(args, channels, input_size) 24 | model.load_checkpoint(args.checkpoint_gpt) 25 | model.outlier_detection(in_loader, out_loader) -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # Unfortunately, flake8 does not support pyproject.toml configuration. 
2 | # https://github.com/PyCQA/flake8/issues/234 3 | [flake8] 4 | per-file-ignores = 5 | __init__.py:F401 6 | show-source = True 7 | count= True 8 | statistics = True 9 | # https://www.flake8rules.com 10 | # E203 = Whitespace before ‘:' 11 | # E265 = comment blocks like @{ section, which it can't handle 12 | # E266 = too many leading '#' for block comment 13 | # E731 = do not assign a lambda expression, use a def 14 | # W293 = Blank line contains whitespace 15 | # W503 = Line break before binary operator 16 | # E704 = multiple statements in one line - used for @override 17 | # TC002 = move third party import to TYPE_CHECKING 18 | # ANN = flake8-annotations 19 | # TC, TC2 = flake8-type-checking 20 | # B = flake8-bugbear 21 | # S = flake8-bandit 22 | # D = flake8-docstrings 23 | # S = flake8-bandit 24 | # F are errors reported by pyflakes 25 | # E and W are warnings and errors reported by pycodestyle 26 | # C are violations reported by mccabe 27 | # BLK = flake8-black 28 | # DAR = darglint 29 | # SC = flake8-spellcheck 30 | ignore = E203, E211, E265, E501, E999, F401, F821, W503, W505, SC100, SC200, C400, C401, C402, B008 31 | max-line-length = 120 32 | max-doc-length = 120 33 | import-order-style = google 34 | docstring-convention = google 35 | inline-quotes = " 36 | strictness=short 37 | dictionaries=en_US,python,technical,pandas 38 | min-python-version = 3.7.0 39 | exclude = .git,.tox,.nox,venv,.venv,.venv-docs,.venv-dev,.venv-note,.venv-dempy,docs,test 40 | max-complexity = 10 41 | #spellcheck-targets=comments 42 | -------------------------------------------------------------------------------- /src/generativezoo/HVAE.py: -------------------------------------------------------------------------------- 1 | from models.VAE.HierarchicalVAE import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_HierarchicalVAE 4 | import torch 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | args = parse_args_HierarchicalVAE() 10 | 11 | size = None 12 | 13 | if args.train: 14 | dataloader, img_size, channels = pick_dataset(args.dataset, size=size, batch_size=args.batch_size, num_workers=args.num_workers) 15 | if not args.no_wandb: 16 | wandb.init(project='HierarchicalVAE', 17 | config={ 18 | 'latent_dim': args.latent_dim, 19 | 'img_size': img_size, 20 | 'channels': channels, 21 | 'batch_size': args.batch_size, 22 | 'epochs': args.n_epochs, 23 | 'dataset': args.dataset 24 | }, 25 | name=f'HierarchicalVAE_{args.dataset}') 26 | 27 | model = HierarchicalVAE(args.latent_dim, (img_size, img_size), channels, args.no_wandb) 28 | model.train_model(dataloader, args) 29 | wandb.finish() 30 | 31 | if args.sample: 32 | _, img_size, channels = pick_dataset(args.dataset, mode='val', size=size, batch_size=args.batch_size, num_workers=args.num_workers) 33 | model = HierarchicalVAE(args.latent_dim, (img_size, img_size), channels) 34 | if args.checkpoint is not None: 35 | model.load_checkpoint(args.checkpoint) 36 | model.sample(16) -------------------------------------------------------------------------------- /src/generativezoo/CondDDPM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.Dataloaders import * 3 | from models.DDPM.ConditionalDDPM import * 4 | from utils.util import parse_args_CDDPM 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | args = parse_args_CDDPM() 11 | normalize = True 12 | 13 | if args.train: 14 | train_dataloader, input_size, 
channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=normalize, size=args.size, num_workers=args.num_workers) 15 | model = ConditionalDDPM(in_channels=channels, input_size=input_size, args=args) 16 | model.train_model(train_dataloader) 17 | 18 | elif args.sample: 19 | _, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=normalize, size=args.size) 20 | model = ConditionalDDPM(in_channels=channels, input_size=input_size, args=args) 21 | if args.checkpoint is not None: 22 | model.model.load_state_dict(torch.load(args.checkpoint, weights_only=False)) 23 | model.sample() 24 | 25 | elif args.fid: 26 | _, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=normalize, size=args.size) 27 | model = ConditionalDDPM(in_channels=channels, input_size=input_size, args=args) 28 | if args.checkpoint is not None: 29 | model.model.load_state_dict(torch.load(args.checkpoint, weights_only=False)) 30 | model.fid_sample() 31 | 32 | else: 33 | raise ValueError('Please specify at least one of the following: train, sample') -------------------------------------------------------------------------------- /src/generativezoo/InstructPix2Pix.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import requests 3 | import torch 4 | from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel 5 | from diffusers.utils import load_image 6 | import argparse 7 | from matplotlib import pyplot as plt 8 | 9 | def parse_args(): 10 | arg_parser = argparse.ArgumentParser(description="InstructPix2Pix") 11 | arg_parser.add_argument("--pix2pix_model", type=str, default='./../../models/InstructPix2Pix/checkpoint-20/unet', help="The name of the Pix2Pix model to use") 12 | arg_parser.add_argument("--image_path", type=str, default="./../../data/processed/pokemons/conditioning_images/0003_mask.png", help="The path to the image to edit") 13 | return arg_parser.parse_args() 14 | 15 | args = parse_args() 16 | 17 | pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", unet = UNet2DConditionModel.from_pretrained(args.pix2pix_model, torch_dtype=torch.float16), torch_dtype=torch.float16).to("cuda") 18 | generator = torch.Generator("cuda").manual_seed(0) 19 | 20 | num_inference_steps = 20 21 | image_guidance_scale = 1.5 22 | guidance_scale = 10 23 | 24 | original_image = load_image(args.image_path) 25 | 26 | while True: 27 | text = input("Enter prompt (0 to exit): ") 28 | if text == "0": 29 | break 30 | image = pipeline(text,image=original_image,num_inference_steps=num_inference_steps,image_guidance_scale=image_guidance_scale,guidance_scale=guidance_scale,generator=generator).images[0] 31 | plt.imshow(image) 32 | plt.show() 33 | new_original = input("Enter new original image (0 to keep the same): ") 34 | if new_original == "0": 35 | continue 36 | original_image = load_image(new_original) 37 | -------------------------------------------------------------------------------- /src/generativezoo/data/CycleGAN_Dataloaders.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torch.utils.data import DataLoader, Dataset 3 | from PIL import Image 4 | import os 5 | 6 | 7 | class Horse2Zebra(Dataset): 8 | def __init__(self, root, dataset, transform=None, train = True, distribution = 'A'): 9 | self.root = root 10 | self.transform = transform 11 | self.dataset = dataset 12 | 
self.train = train 13 | self.distribution = distribution 14 | if self.train: 15 | self.files = sorted(os.listdir(os.path.join(root, dataset, 'train' + distribution))) 16 | else: 17 | self.files = sorted(os.listdir(os.path.join(root, dataset, 'test' + distribution))) 18 | 19 | def __len__(self): 20 | return len(self.files) 21 | 22 | def __getitem__(self, index): 23 | if self.train: 24 | img = Image.open(os.path.join(self.root, self.dataset, 'train' + self.distribution, self.files[index])).convert('RGB') 25 | else: 26 | img = Image.open(os.path.join(self.root, self.dataset, 'test' + self.distribution, self.files[index])).convert('RGB') 27 | if self.transform is not None: 28 | img = self.transform(img) 29 | return img 30 | 31 | def get_horse2zebra_dataloader(root, dataset, batch_size, train = True, distribution = 'A', input_size = 128): 32 | transform = transforms.Compose([ 33 | transforms.Resize((input_size, input_size)), 34 | transforms.ToTensor(), 35 | transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)) 36 | ]) 37 | set = Horse2Zebra(root, dataset, transform, train, distribution) 38 | dataloader = DataLoader(set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True) 39 | return dataloader -------------------------------------------------------------------------------- /.envrc: -------------------------------------------------------------------------------- 1 | # Orchestration file to load environment variables from the `.env` and `.secrets` files. 2 | # Environment variables go here, and can be read in by Python using `os.getenv`: 3 | # 4 | # Only used by systems with `direnv` (https://direnv.net/) installed. Environment 5 | # variables can be read in by Python using `os.getenv` _without_ using `python-dotenv`: 6 | # 7 | # -------------------------------------------------------- 8 | # import os 9 | # 10 | # # Example variable 11 | # EXAMPLE_VARIABLE = os.getenv("EXAMPLE_VARIABLE") 12 | # -------------------------------------------------------- 13 | # 14 | # To ensure the `sed` command below works correctly, make sure all file paths in environment variables are absolute 15 | # (recommended), or are relative paths using other environment variables (works for Python users only). Environment 16 | # variable names are expected to contain letters, numbers or underscores only. 17 | # 18 | # DO NOT STORE SECRETS HERE - this file is version-controlled! You should store secrets in a `.secrets` file, which is 19 | # not version-controlled - this can then be sourced here, using `source_env ".secrets"`. 20 | 21 | # Extract the variables to `.env` if required. 
Note `.env` is NOT version-controlled, so `.secrets` will not be committed 22 | #sed -n 's/^export \(.*\)$/\1/p' .envrc .secrets | sed -e 's?$(pwd)?'"$(pwd)"'?g' | sed -e 's?$\([a-zA-Z0-9_]\{1,\}\)?${\1}?g' >> .env 23 | 24 | # Add the working directory to `PYTHONPATH`; allows Jupyter notebooks in the `notebooks` folder to import `src/generativezoo` 25 | export PYTHONPATH="$PYTHONPATH:$(pwd)/src/generativezoo" 26 | 27 | # Load the `.env` file 28 | dotenv .env 29 | 30 | # Import secrets from an untracked file `.secrets` (if it exists) 31 | dotenv_if_exists .secrets 32 | # Activate the virtual environment if it exists 33 | source .venv/bin/activate 34 | -------------------------------------------------------------------------------- /src/generativezoo/CondGan.py: -------------------------------------------------------------------------------- 1 | from models.GAN.ConditionalGAN import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_CondGAN 4 | import torch 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | args = parse_args_CondGAN() 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | size = None 13 | 14 | if args.train: 15 | if not args.no_wandb: 16 | wandb.init(project='ConditionalGAN', 17 | config={ 18 | 'dataset': args.dataset, 19 | 'batch_size': args.batch_size, 20 | 'n_epochs': args.n_epochs, 21 | 'latent_dim': args.latent_dim, 22 | 'd': args.d, 23 | 'lr': args.lr, 24 | 'beta1': args.beta1, 25 | 'beta2': args.beta2, 26 | 'sample_and_save_freq': args.sample_and_save_freq 27 | }, 28 | name = 'ConditionalGAN_{}'.format(args.dataset)) 29 | train_dataloader, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=args.batch_size, normalize = True, size = size, num_workers=args.num_workers) 30 | model = ConditionalGAN(img_size=input_size, channels=channels, args=args) 31 | model.train_model(train_dataloader) 32 | 33 | elif args.sample: 34 | _, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=1, normalize = True, size = size) 35 | model = Generator(n_classes = args.n_classes, latent_dim = args.latent_dim, channels = channels, d = args.d).to(device) 36 | model.load_state_dict(torch.load(args.checkpoint)) 37 | model.eval() 38 | model.sample(n_samples = args.n_samples, device = device) 39 | else: 40 | raise Exception('Please specify either --train or --sample') 41 | -------------------------------------------------------------------------------- /src/generativezoo/config.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from os import environ, getenv 3 | from os.path import dirname, join 4 | from pathlib import Path 5 | 6 | import dotenv 7 | 8 | # Project root 9 | project_dir = dirname(dirname(dirname(__file__))) 10 | 11 | # Load the environment variables from the `.env` file, overriding any system environment variables 12 | env_path = join(project_dir, ".env") 13 | dotenv.load_dotenv(env_path, override=True) 14 | 15 | # Load secrets from the `.secrets` file, overriding any system environment variables 16 | secrets_path = join(project_dir, ".secrets") 17 | dotenv.load_dotenv(secrets_path, override=True) 18 | 19 | # Some common paths 20 | _reports_dir = Path(str(getenv("DIR_REPORTS"))) 21 | report_dir = join(project_dir, _reports_dir) 22 | 23 | _figures_dir = Path(str(getenv("DIR_FIGURES"))) 24 | figures_dir = join(project_dir, _figures_dir) 25 | 26 | _models_dir = Path(str(getenv("DIR_MODELS"))) 27 | models_dir = 
join(project_dir, _models_dir) 28 | 29 | _notebook_dir = Path(str(getenv("DIR_NOTEBOOKS"))) 30 | notebook_dir = join(project_dir, _notebook_dir) 31 | 32 | _data_dir = Path(str(getenv("DIR_DATA"))) 33 | data_dir = join(project_dir, _data_dir) 34 | 35 | _data_raw_dir = Path(str(getenv("DIR_DATA_RAW"))) 36 | data_raw_dir = join(project_dir, _data_raw_dir) 37 | 38 | _data_interim_dir = Path(str(getenv("DIR_DATA_INTERIM"))) 39 | data_interim_dir = join(project_dir, _data_interim_dir) 40 | 41 | _data_processed_dir = Path(str(getenv("DIR_DATA_PROCESSED"))) 42 | data_processed_dir = join(project_dir, _data_processed_dir) 43 | 44 | # CUDA Enable 45 | ENABLE_CUDA = True 46 | if not ENABLE_CUDA: 47 | environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 48 | environ["CUDA_VISIBLE_DEVICES"] = "-1" 49 | environ["USE_CPU"] = "1" 50 | 51 | # Hydra 52 | environ["HYDRA_FULL_ERROR"] = "1" 53 | 54 | # log to mlflow 55 | LOG_TO_MLFLOW = False 56 | 57 | _IS_WINDOWS = platform.system() == "Windows" 58 | 59 | # EOF 60 | -------------------------------------------------------------------------------- /src/generativezoo/CondFM.py: -------------------------------------------------------------------------------- 1 | from models.FM.CondFlowMatching import CondFlowMatching 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_CondFlowMatching 4 | import wandb 5 | 6 | if __name__ == '__main__': 7 | 8 | args = parse_args_CondFlowMatching() 9 | 10 | if args.train: 11 | train_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 12 | model = CondFlowMatching(args, input_size, channels) 13 | model.train_model(train_loader) 14 | wandb.finish() 15 | 16 | elif args.sample: 17 | _, input_size, channels = pick_dataset(args.dataset, batch_size = 1, normalize=True, size=args.size) 18 | model = CondFlowMatching(args, input_size, channels) 19 | model.load_checkpoint(args.checkpoint) 20 | model.sample(args.num_samples, train=False) 21 | elif args.fid: 22 | _, input_size, channels = pick_dataset(args.dataset, batch_size = 1, normalize=True, size=args.size) 23 | model = CondFlowMatching(args, input_size, channels) 24 | model.load_checkpoint(args.checkpoint) 25 | model.fid_sample() 26 | 27 | elif args.translation: 28 | val_loader, input_size, channels = pick_dataset(args.dataset, batch_size = 16, normalize=True, size=args.size, mode='val') 29 | model = CondFlowMatching(args, input_size, channels) 30 | model.load_checkpoint(args.checkpoint) 31 | model.image_translation(val_loader) 32 | elif args.classification: 33 | val_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, size=args.size, mode='val') 34 | model = CondFlowMatching(args, input_size, channels) 35 | model.load_checkpoint(args.checkpoint) 36 | model.classification(val_loader) 37 | else: 38 | raise ValueError("Invalid mode, please specify train or sample mode.") 39 | -------------------------------------------------------------------------------- /src/generativezoo/VanFlow.py: -------------------------------------------------------------------------------- 1 | from models.NF.VanillaFlow import VanillaFlow 2 | from utils.util import parse_args_VanillaFlow 3 | from data.Dataloaders import * 4 | import wandb 5 | 6 | if __name__ == '__main__': 7 | 8 | args = parse_args_VanillaFlow() 9 | 10 | if args.train: 11 | in_loader, img_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, num_workers=args.num_workers) 12 
| if not args.no_wandb: 13 | wandb.init(project = "VanillaFlow", 14 | config = { 15 | "dataset": args.dataset, 16 | "batch_size": args.batch_size, 17 | "epochs": args.n_epochs, 18 | "lr": args.lr, 19 | "img_size": img_size, 20 | "c_hidden": args.c_hidden, 21 | "n_layers": args.n_layers, 22 | "multi_scale": args.multi_scale, 23 | "vardeq": args.vardeq, 24 | }, 25 | name = f"VanillaFlow_{args.dataset}") 26 | 27 | model = VanillaFlow(img_size, channels, args) 28 | model.train_model(in_loader, args) 29 | wandb.finish() 30 | 31 | elif args.sample: 32 | _, img_size, channels = pick_dataset(args.dataset, 'val', args.batch_size) 33 | model = VanillaFlow(img_size, channels, args) 34 | if args.checkpoint is not None: 35 | model.flows.load_state_dict(torch.load(args.checkpoint)) 36 | model.sample(train=False) 37 | 38 | elif args.outlier_detection: 39 | in_loader, img_size, channels = pick_dataset(args.dataset, 'val', args.batch_size) 40 | out_loader, _, _ = pick_dataset(args.dataset, 'val', args.batch_size, size=img_size) 41 | model = VanillaFlow(img_size, channels, args) 42 | if args.checkpoint is not None: 43 | model.flows.load_state_dict(torch.load(args.checkpoint)) 44 | model.outlier_detection(in_loader, out_loader) -------------------------------------------------------------------------------- /docs/MODELRULES.md: -------------------------------------------------------------------------------- 1 | # Model Creation Guidelines 2 | 3 | Thank you for your interest in contributing to our repository! Below are the guidelines for creating models: 4 | 5 | 1. **Flexibility**: You are encouraged to design and implement models as you see fit. There are no restrictions on the architecture or techniques used. 6 | 7 | 2. **Self-contained Class**: Each model should be encapsulated within a self-contained class. This class should include, at minimum, the following methods: 8 | - `train_model`: This method should contain the necessary code for training the model. It should accept training data as input and update the model parameters accordingly. 9 | - `sample` (optional): If applicable, this method should generate samples from the trained model. It may take additional parameters for controlling the sampling process. 10 | 11 | 3. **Documentation**: Provide clear and concise documentation within the class to explain its functionality and usage. Document any additional features or functionalities beyond the basic training and sampling methods. 12 | 13 | 4. **Additional Features**: You are welcome to implement additional features within the model class, such as outlier detection, data augmentation, or specialized sampling techniques. Document these features thoroughly to facilitate understanding and usage. 14 | 15 | 5. **Readability and Maintainability**: Write clean, readable, and well-commented code to ensure that others can easily understand and modify your implementation if needed. 16 | 17 | 6. **Dependencies**: Minimize external dependencies and ensure that all required libraries are listed in the repository's `requirements.txt` file. 18 | 19 | 7. **Testing**: Whenever possible, include unit tests to verify the correctness and robustness of your implementation. 20 | 21 | 8. **Licensing**: Ensure that your code complies with the repository's licensing terms and that you have the necessary permissions to contribute it. 22 | 23 | By following these guidelines, you can create models that are easy to understand, use, and integrate into our repository. Thank you for your contributions! 
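For reference, here is a minimal sketch of the self-contained class described in point 2. Only the `train_model`/`sample` interface is prescribed by these guidelines; the class name, the constructor signature, the `(image, label)` batch format, and the toy per-pixel Gaussian model are illustrative assumptions, not a required implementation.

```python
import torch
import torch.nn as nn


class ToyGenerativeModel(nn.Module):
    """Illustrative generative model: fits an independent Gaussian to each pixel."""

    def __init__(self, channels: int, input_size: int):
        super().__init__()
        # Buffers for the per-pixel mean and standard deviation of the training data.
        self.register_buffer("mean", torch.zeros(channels, input_size, input_size))
        self.register_buffer("std", torch.ones(channels, input_size, input_size))

    @torch.no_grad()
    def train_model(self, dataloader) -> None:
        """Estimate the per-pixel mean from a dataloader yielding (image, label) batches."""
        total, count = None, 0
        for images, _ in dataloader:
            batch_sum = images.sum(dim=0)
            total = batch_sum if total is None else total + batch_sum
            count += images.shape[0]
        self.mean.copy_(total / count)
        # A second pass to estimate the standard deviation is omitted for brevity.

    @torch.no_grad()
    def sample(self, n_samples: int = 16) -> torch.Tensor:
        """Draw samples from the fitted per-pixel Gaussian."""
        return self.mean + self.std * torch.randn(n_samples, *self.mean.shape)
```

An entry script can then construct the class, call `train_model` on a dataloader, and call `sample`, mirroring the existing entry points in `src/generativezoo`.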
24 | -------------------------------------------------------------------------------- /src/generativezoo/FM.py: -------------------------------------------------------------------------------- 1 | from models.FM.FlowMatching import FlowMatching 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_FlowMatching 4 | import wandb 5 | 6 | if __name__ == '__main__': 7 | 8 | args = parse_args_FlowMatching() 9 | 10 | if args.train: 11 | train_loader, input_size, channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, num_workers=args.num_workers, size=args.size) 12 | model = FlowMatching(args, input_size, channels) 13 | model.train_model(train_loader) 14 | wandb.finish() 15 | 16 | elif args.sample: 17 | _, input_size, channels = pick_dataset(args.dataset, batch_size = 1, normalize=True, size=args.size) 18 | model = FlowMatching(args, input_size, channels) 19 | model.load_checkpoint(args.checkpoint) 20 | model.sample(args.num_samples, train=False) 21 | 22 | elif args.outlier_detection: 23 | in_loader, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, size=args.size) 24 | out_loader, _, _ = pick_dataset(args.out_dataset, mode='val', batch_size = args.batch_size, normalize=True, size=input_size) 25 | model = FlowMatching(args, input_size, channels) 26 | model.load_checkpoint(args.checkpoint) 27 | model.outlier_detection(in_loader, out_loader) 28 | 29 | elif args.interpolation: 30 | in_loader, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, size=args.size) 31 | model = FlowMatching(args, input_size, channels) 32 | model.load_checkpoint(args.checkpoint) 33 | model.interpolate(in_loader) 34 | elif args.fid: 35 | _, input_size, channels = pick_dataset(args.dataset, mode='val', batch_size = args.batch_size, normalize=True, size=args.size) 36 | model = FlowMatching(args, input_size, channels) 37 | model.load_checkpoint(args.checkpoint) 38 | model.fid_sample(args.batch_size) 39 | else: 40 | raise ValueError("Invalid mode, please specify train or sample mode.") 41 | -------------------------------------------------------------------------------- /src/generativezoo/CondVAE.py: -------------------------------------------------------------------------------- 1 | from models.VAE.ConditionalVAE import * 2 | from data.Dataloaders import * 3 | import torch 4 | from utils.util import parse_args_ConditionalVAE 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | args = parse_args_ConditionalVAE() 12 | 13 | size = None 14 | 15 | if args.train: 16 | # train dataloader 17 | train_loader, in_shape, in_channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, size=size, num_workers=args.num_workers) 18 | if not args.no_wandb: 19 | wandb.init(project='CVAE', 20 | 21 | config={ 22 | 'dataset': args.dataset, 23 | 'batch_size': args.batch_size, 24 | 'n_epochs': args.n_epochs, 25 | 'lr': args.lr, 26 | 'latent_dim': args.latent_dim, 27 | 'hidden_dims': args.hidden_dims, 28 | 'input_size': in_shape, 29 | 'channels': in_channels, 30 | 'num_classes': args.num_classes, 31 | 'loss_type': args.loss_type, 32 | 'kld_weight': args.kld_weight 33 | }, 34 | 35 | name = 'CVAE_{}'.format(args.dataset)) 36 | # create model 37 | model = ConditionalVAE(input_shape=in_shape, input_channels=in_channels, args=args) 38 | # train model 39 | model.train_model(train_loader, 
args.n_epochs) 40 | 41 | elif args.sample: 42 | _, in_shape, in_channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, size=size) 43 | model = ConditionalVAE(input_shape=in_shape, input_channels=in_channels, args=args) 44 | model.load_state_dict(torch.load(args.checkpoint)) 45 | model.sample(title="Sample", train = False) 46 | else: 47 | raise ValueError("Invalid mode. Please specify train or sample") -------------------------------------------------------------------------------- /src/generativezoo/DAE.py: -------------------------------------------------------------------------------- 1 | from models.DDPM.MONAI_DiffAE import DiffAE 2 | import torch 3 | from data.Dataloaders import * 4 | from utils.util import parse_args_DiffAE 5 | import wandb 6 | 7 | 8 | if __name__ == '__main__': 9 | 10 | device = "cuda" if torch.cuda.is_available() else "cpu" 11 | args = parse_args_DiffAE() 12 | 13 | size = None 14 | 15 | if args.train: 16 | train_dataloader, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=True, size=size, num_workers=args.num_workers) 17 | if not args.no_wandb: 18 | wandb.init(project='DiffAE', 19 | config={ 20 | 'dataset': args.dataset, 21 | 'batch_size': args.batch_size, 22 | 'n_epochs': args.n_epochs, 23 | 'lr': args.lr, 24 | 'embedding_dim': args.embedding_dim, 25 | 'timesteps': args.timesteps, 26 | 'sample_timesteps': args.sample_timesteps, 27 | 'model_channels': args.model_channels, 28 | 'attention_levels': args.attention_levels, 29 | 'num_res_blocks': args.num_res_blocks, 30 | 'input_size': input_size, 31 | 'channels': channels, 32 | }, 33 | name = 'DiffAE_{}'.format(args.dataset)) 34 | model = DiffAE(args, channels) 35 | model.train_model(train_dataloader, train_dataloader) 36 | wandb.finish() 37 | 38 | elif args.manipulate: 39 | train_dataloader, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=True, size = size) 40 | val_dataloader, _, _ = pick_dataset(args.dataset, 'val', args.batch_size, normalize=True, size=size) 41 | model = DiffAE(args, channels) 42 | model.unet.load_state_dict(torch.load(args.checkpoint)) 43 | model.linear_regression(train_dataloader, val_dataloader) 44 | model.manipulate_latent(val_dataloader) 45 | -------------------------------------------------------------------------------- /src/generativezoo/P-CNN.py: -------------------------------------------------------------------------------- 1 | from models.AR.PixelCNN import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_PixelCNN 4 | import wandb 5 | 6 | if __name__ == '__main__': 7 | 8 | args = parse_args_PixelCNN() 9 | 10 | size = None 11 | 12 | if args.train: 13 | dataloader, img_size, channels = pick_dataset(args.dataset, normalize=False, batch_size=args.batch_size, size=size, num_workers=args.num_workers) 14 | if not args.no_wandb: 15 | wandb.init(project="PixelCNN", 16 | config = { 17 | "batch_size": args.batch_size, 18 | "hidden_channels": args.hidden_channels, 19 | "n_epochs": args.n_epochs, 20 | "lr": args.lr, 21 | "gamma": args.gamma, 22 | "image_size": img_size, 23 | "dataset": args.dataset, 24 | "channels": channels 25 | }, 26 | name=f"PixelCNN_{args.dataset}" 27 | ) 28 | 29 | model = PixelCNN(channels, args.hidden_channels, args.no_wandb) 30 | model.train_model(dataloader, args, img_size) 31 | wandb.finish() 32 | 33 | elif args.sample: 34 | _, img_size, channels = pick_dataset(args.dataset, normalize=False, batch_size=args.batch_size, size=size) 35 | model = 
PixelCNN(channels, args.hidden_channels) 36 | 37 | if args.checkpoint is not None: 38 | model.load_state_dict(torch.load(args.checkpoint)) 39 | 40 | model.sample((16,channels,img_size,img_size), train=False) 41 | 42 | elif args.outlier_detection: 43 | in_loader, img_size, channels = pick_dataset(args.dataset, normalize=False, batch_size=args.batch_size, size=size) 44 | out_loader, _, _ = pick_dataset(args.out_dataset, normalize=False, batch_size=args.batch_size, size=img_size) 45 | model = PixelCNN(channels, args.hidden_channels) 46 | 47 | if args.checkpoint is not None: 48 | model.load_state_dict(torch.load(args.checkpoint)) 49 | 50 | model.outlier_detection(in_loader, out_loader) -------------------------------------------------------------------------------- /.env: -------------------------------------------------------------------------------- 1 | # Environment variables go here, can be read by `python-dotenv` package, and `os.getenv`: 2 | # 3 | # `src/generativezoo/config.py` 4 | # ---------------------------------------------------------------- 5 | # from os import getenv 6 | # from os.path import dirname, join 7 | # import dotenv 8 | # 9 | # project_dir = dirname(dirname(dirname(__file__))) 10 | # 11 | # # Load the environment variables from the `.env` file, overriding any system environment variables 12 | # env_path = join(project_dir, '.env') 13 | # dotenv.load_dotenv(env_path, override=True) 14 | # 15 | # # Load secrets from the `.secrets` file, overriding any system environment variables 16 | # secrets_path = join(project_dir, '.secrets') 17 | # load_dotenv(secrets_path, override=True) 18 | # 19 | # # Example variable 20 | # EXAMPLE_VARIABLE = getenv("EXAMPLE_VARIABLE") 21 | # 22 | # ---------------------------------------------------------------- 23 | # 24 | # DO NOT STORE SECRETS HERE! You should store secrets in a `.secrets` file, which is not versioned 25 | DOMAIN=localhost 26 | GUNICORN_WORKERS=1 27 | LOG_LEVEL=debug 28 | # For folder/file path environment variables, use relative paths. 
29 | # Add environment variables for the `data` directories 30 | DIR_DATA=./data 31 | DIR_DATA_EXTERNAL=./data/external 32 | DIR_DATA_RAW=./data/raw 33 | DIR_DATA_INTERIM=./data/interim 34 | DIR_DATA_PROCESSED=./data/processed 35 | 36 | # Add environment variables for the `docs` directory 37 | DIR_DOCS=./docs 38 | 39 | # Add environment variables for the `notebooks` directory 40 | DIR_NOTEBOOKS=./notebooks 41 | 42 | # Add environment variables for the `reports` directory 43 | DIR_REPORTS=./reports 44 | DIR_FIGURES=./reports/figures 45 | 46 | # Add environment variables for the `models` directory 47 | DIR_MODELS=./models 48 | 49 | # Add environment variables for the `src` directories 50 | DIR_SRC=./src/generativezoo/ 51 | DIR_SRC_DATA=./src/generativezoo/data 52 | DIR_SRC_FEATURES=./src/generativezoo/features 53 | DIR_SRC_MODELS=./src/generativezoo/models 54 | DIR_SRC_VISUALISATION=./src/generativezoo/visualisation 55 | DIR_SRC_UTILS=./src/generativezoo/utils 56 | 57 | # Add environment variables for the `tests` directory 58 | DIR_TESTS=./tests 59 | -------------------------------------------------------------------------------- /src/generativezoo/FlowPP.py: -------------------------------------------------------------------------------- 1 | from models.NF.FlowPlusPlus import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_FlowPP 4 | import wandb 5 | 6 | if __name__ == '__main__': 7 | 8 | args = parse_args_FlowPP() 9 | 10 | size = None 11 | 12 | if args.train: 13 | train_loader, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=False, size=size, num_workers=args.num_workers) 14 | if not args.no_wandb: 15 | wandb.init(project='FlowPlusPlus', 16 | config={ 17 | 'dataset': args.dataset, 18 | 'batch_size': args.batch_size, 19 | 'n_epochs': args.n_epochs, 20 | 'warm_up': args.warm_up, 21 | 'lr': args.lr, 22 | 'grad_clip': args.grad_clip, 23 | 'num_blocks': args.num_blocks, 24 | 'num_components': args.num_components, 25 | 'num_channels': args.num_channels, 26 | 'use_attn': args.use_attn, 27 | 'num_dequant_blocks': args.num_dequant_blocks, 28 | 'drop_prob': args.drop_prob, 29 | }, 30 | name = 'FlowPlusPlus_{}'.format(args.dataset)) 31 | model = FlowPlusPlus(args, channels=channels, img_size=input_size) 32 | model.train_model(args, train_loader) 33 | wandb.finish() 34 | 35 | elif args.sample: 36 | _, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=False, size=size) 37 | model = FlowPlusPlus(args, channels=channels, img_size=input_size) 38 | model.load_checkpoints(args) 39 | model.sample(16, False) 40 | 41 | elif args.outlier_detection: 42 | in_loader, input_size, channels = pick_dataset(args.dataset, 'val', args.batch_size, normalize=False, size=size) 43 | out_loader, _, _ = pick_dataset(args.dataset, 'val', args.batch_size, normalize=False, size=input_size) 44 | model = FlowPlusPlus(args, channels=channels, img_size=input_size) 45 | model.load_checkpoints(args) 46 | model.outlier_detection(in_loader, out_loader) -------------------------------------------------------------------------------- /src/generativezoo/RNVP.py: -------------------------------------------------------------------------------- 1 | from models.NF.RealNVP import RealNVP 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_RealNVP 4 | import wandb 5 | 6 | if __name__ == '__main__': 7 | 8 | args = parse_args_RealNVP() 9 | 10 | size = None 11 | 12 | if args.train: 13 | dataloader, img_size, channels = 
pick_dataset(args.dataset, batch_size=args.batch_size, normalize = False, size=size, num_workers=args.num_workers) 14 | model = RealNVP(img_size=img_size, in_channels=channels, args=args) 15 | 16 | if not args.no_wandb: 17 | wandb.init(project='RealNVP', 18 | 19 | config = {"dataset": args.dataset, 20 | "num_scales": args.num_scales, 21 | "mid_channels": args.mid_channels, 22 | "num_blocks": args.num_blocks, 23 | "batch_size": args.batch_size, 24 | "lr": args.lr, 25 | "n_epochs": args.n_epochs, 26 | "img_size": img_size, 27 | "channels": channels}, 28 | 29 | name=f"RealNVP_{args.dataset}") 30 | 31 | model.train_model(dataloader, args) 32 | wandb.finish() 33 | 34 | elif args.sample: 35 | _, img_size, channels = pick_dataset(args.dataset, batch_size=1, normalize = False, size=size) 36 | model = RealNVP(img_size=img_size, in_channels=channels, args=args) 37 | 38 | if args.checkpoint is not None: 39 | model.load_state_dict(torch.load(args.checkpoint)) 40 | 41 | model.sample(16, train=False) 42 | 43 | elif args.outlier_detection: 44 | in_loader, img_size, channels = pick_dataset(args.dataset, mode = 'val', batch_size=args.batch_size, normalize = False, size=size) 45 | out_loader, _, _ = pick_dataset(args.out_dataset, mode = 'val', batch_size=args.batch_size, normalize = False, size=img_size) 46 | model = RealNVP(img_size=img_size, in_channels=channels, args=args) 47 | 48 | if args.checkpoint is not None: 49 | model.load_state_dict(torch.load(args.checkpoint)) 50 | 51 | model.outlier_detection(in_loader, out_loader) 52 | 53 | -------------------------------------------------------------------------------- /docs/Text2Img_LoRA.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion 2.1 Text-to-Image with LoRA 2 | 3 | ## Prepare the Dataset 4 | 5 | If you want to train the model in a custom dataset, two elements must be provided: the `images` and the `prompts`. The images must be provided in a folder organized as follows: 6 | 7 | ```bash 8 | ├── dataset 9 | │ ├── images 10 | │ │ ├── 11 | │ │ ├── 12 | │ │ ├── ... 13 | │ ├── train.jsonl 14 | ``` 15 | 16 | The file `train.jsonl` contains the prompts associated to each image and should be structured like the following example: 17 | 18 | ```json 19 | {"text": "a drawing of a green pokemon with red eyes", "image": "./../../../../data/processed/pokemons/images/0000.png"} 20 | {"text": "a green and yellow toy with a red nose", "image": "./../../../../data/processed/pokemons/images/0001.png"} 21 | ... 22 | ``` 23 | 24 | ## Accelerator Config 25 | 26 | Please configure the accelerator to match your system requirements by running: 27 | 28 | ```bash 29 | accelerate config 30 | ``` 31 | 32 | ## Train the model 33 | 34 | A [`script file`](./../scripts/Text2Img_Lora.sh) is provided with the commands required to train the model on a custom dataset. Several parameters should be configured, mainly: 35 | 36 | ```sh 37 | export OUTPUT_DIR="./../../../../models/Text2Img_Lora/pokemons" 38 | export DATASET_PATH="./../../../../data/processed/pokemons" 39 | ``` 40 | 41 | Several parameters in the training command can be tuned by the user, particularly the `--validation_prompt` which should reflect the use case of this training. 
42 | 43 | ```sh 44 | accelerate launch --mixed_precision="fp16" Text2Img_Lora.py \ 45 | --pretrained_model_name_or_path=$MODEL_NAME \ 46 | --train_data_dir=$DATASET_PATH \ 47 | --dataloader_num_workers=8 \ 48 | --resolution=512 \ 49 | --center_crop \ 50 | --random_flip \ 51 | --train_batch_size=2 \ 52 | --gradient_accumulation_steps=1 \ 53 | --max_train_steps=15000 \ 54 | --learning_rate=1e-04 \ 55 | --max_grad_norm=1 \ 56 | --lr_scheduler="cosine" \ 57 | --lr_warmup_steps=0 \ 58 | --output_dir=${OUTPUT_DIR} \ 59 | --report_to=wandb \ 60 | --checkpointing_steps=500 \ 61 | --validation_prompt="A pokemon with blue wings." \ 62 | --seed=1337 \ 63 | --image_column="image" \ 64 | --caption_column="text" 65 | ``` 66 | 67 | ## Inference 68 | 69 | A Python script is provided to use the trained LoRA adapters: 70 | 71 | ```bash 72 | python Text2Img_LoRA.py --lora_model_path ./../../models/Text2Img_Lora/pokemons/checkpoint-60 73 | ``` -------------------------------------------------------------------------------- /src/generativezoo/DDPM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data.Dataloaders import * 3 | from models.DDPM.DDPM import * 4 | from utils.util import parse_args_DDPM 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | args = parse_args_DDPM() 11 | normalize = True 12 | 13 | if args.train: 14 | dataloader, input_size, channels = pick_dataset(args.dataset, 'train', args.batch_size, normalize=normalize, size=args.size, num_workers=args.num_workers) 15 | model = DDPM(args, channels=channels, image_size=input_size) 16 | model.train_model(dataloader) 17 | wandb.finish() 18 | 19 | elif args.sample: 20 | _, input_size, channels = pick_dataset(args.dataset, 'val', args.batch_size, normalize=normalize, size=args.size) 21 | model = DDPM(args, channels=channels, image_size=input_size) 22 | model.model.load_state_dict(torch.load(args.checkpoint, weights_only=False)) 23 | model.sample(args.num_samples) 24 | 25 | elif args.inpaint: 26 | dataloader, input_size, channels = pick_dataset(args.dataset, 'val', args.batch_size, normalize=normalize, size=args.size) 27 | model = DDPM(args, channels=channels, image_size=input_size) 28 | model.model.load_state_dict(torch.load(args.checkpoint, weights_only=False)) 29 | model.inpaint(dataloader) 30 | 31 | elif args.outlier_detection: 32 | dataloader_a, input_size, channels = pick_dataset(args.dataset, 'val', args.batch_size, normalize=normalize, size=args.size) 33 | model = DDPM(args, channels=channels, image_size=input_size) 34 | model.model.load_state_dict(torch.load(args.checkpoint, weights_only=False)) 35 | dataloader_b, input_size_b, channels_b = pick_dataset(args.out_dataset, 'val', args.batch_size, normalize=normalize, good = False, size=input_size) 36 | model.outlier_detection(dataloader_a,dataloader_b, args.dataset, args.out_dataset) 37 | 38 | elif args.fid: 39 | _, input_size, channels = pick_dataset(args.dataset, 'val', args.batch_size, normalize=normalize, size=args.size) 40 | model = DDPM(args, channels=channels, image_size=input_size) 41 | if args.checkpoint is not None: 42 | model.model.load_state_dict(torch.load(args.checkpoint, weights_only=False)) 43 | model.fid_sample(args.batch_size) 44 | 45 | else: 46 | raise ValueError('Please specify at least one of the following: train, sample, outlier_detection') -------------------------------------------------------------------------------- /src/generativezoo/DCGAN.py: 
-------------------------------------------------------------------------------- 1 | from models.GAN.DCGAN import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_DCGAN 4 | import torch 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | args = parse_args_DCGAN() 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | size = None 12 | 13 | if args.train: 14 | if not args.no_wandb: 15 | wandb.init(project='DCGAN', 16 | config={ 17 | 'dataset': args.dataset, 18 | 'batch_size': args.batch_size, 19 | 'n_epochs': args.n_epochs, 20 | 'latent_dim': args.latent_dim, 21 | 'd': args.d, 22 | 'lrg': args.lrg, 23 | 'lrd': args.lrd, 24 | 'beta1': args.beta1, 25 | 'beta2': args.beta2, 26 | 'sample_and_save_freq': args.sample_and_save_freq 27 | }, 28 | name = 'DCGAN_{}'.format(args.dataset)) 29 | train_dataloader, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=args.batch_size, normalize = True, size = size, num_workers=args.num_workers) 30 | model = DCGAN(args, channels, input_size) 31 | model.train_model(train_dataloader) 32 | 33 | elif args.sample: 34 | _, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=1, normalize = True, size = size) 35 | model = Generator(latent_dim = args.latent_dim, channels = channels, d=args.d, imgSize=input_size).to(device) 36 | model.load_state_dict(torch.load(args.checkpoint)) 37 | model.eval() 38 | model.sample(n_samples = args.n_samples, device = device) 39 | 40 | elif args.outlier_detection: 41 | in_loader, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=args.batch_size, normalize = True, size = size, mode='val') 42 | out_loader, _, _ = pick_dataset(dataset_name = args.out_dataset, batch_size=args.batch_size, normalize = True, size = input_size, mode='val') 43 | model = Discriminator(channels=channels, d=args.d, imgSize=input_size).to(device) 44 | model.load_state_dict(torch.load(args.discriminator_checkpoint)) 45 | model.eval() 46 | model.outlier_detection(in_loader, out_loader, display=True, device=device) 47 | 48 | else: 49 | raise Exception('Please specify either --train, --sample or --outlier_detection. 
For more information use --help.') -------------------------------------------------------------------------------- /src/generativezoo/VanVAE.py: -------------------------------------------------------------------------------- 1 | from models.VAE.VanillaVAE import * 2 | from data.Dataloaders import * 3 | import torch 4 | from utils.util import parse_args_VanillaVAE 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | args = parse_args_VanillaVAE() 12 | 13 | size = None 14 | 15 | if args.train: 16 | # train dataloader 17 | train_loader, in_shape, in_channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, size = size, num_workers=args.num_workers) 18 | if not args.no_wandb: 19 | wandb.init(project='VAE', 20 | 21 | config={ 22 | 'dataset': args.dataset, 23 | 'batch_size': args.batch_size, 24 | 'n_epochs': args.n_epochs, 25 | 'lr': args.lr, 26 | 'latent_dim': args.latent_dim, 27 | 'hidden_dims': args.hidden_dims, 28 | 'input_size': in_shape, 29 | 'channels': in_channels, 30 | 'loss_type': args.loss_type, 31 | 'kld_weight': args.kld_weight 32 | }, 33 | 34 | name = 'VAE_{}'.format(args.dataset)) 35 | # create model 36 | model = VanillaVAE(input_shape=in_shape, input_channels=in_channels,args=args) 37 | # train model 38 | model.train_model(train_loader, args.n_epochs) 39 | 40 | elif args.sample: 41 | _, in_shape, in_channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, size = size) 42 | model = VanillaVAE(input_shape=in_shape, input_channels=in_channels,args=args) 43 | model.load_state_dict(torch.load(args.checkpoint)) 44 | model.sample(title="Sample", train = False) 45 | 46 | elif args.outlier_detection: 47 | in_loader, in_shape, in_channels = pick_dataset(args.dataset, batch_size = args.batch_size, normalize=True, size = size, mode='val') 48 | out_loader, _, _ = pick_dataset(args.out_dataset, batch_size = args.batch_size, normalize=True, size = in_shape, mode='val') 49 | model = VanillaVAE(input_shape=in_shape, input_channels=in_channels,args=args) 50 | if args.checkpoint is not None: 51 | model.load_state_dict(torch.load(args.checkpoint)) 52 | model.outlier_detection(in_loader, out_loader) 53 | else: 54 | raise ValueError("Invalid mode. 
Please specify train or sample") -------------------------------------------------------------------------------- /src/generativezoo/CycGAN.py: -------------------------------------------------------------------------------- 1 | from models.GAN.CycleGAN import * 2 | from data.CycleGAN_Dataloaders import * 3 | from config import data_raw_dir 4 | import torch 5 | import wandb 6 | from utils.util import parse_args_CycleGAN 7 | 8 | if __name__ == '__main__': 9 | 10 | args = parse_args_CycleGAN() 11 | 12 | device = "cuda" if torch.cuda.is_available() else "cpu" 13 | 14 | if args.train: 15 | if not args.no_wandb: 16 | wandb.init(project='CycleGAN', 17 | config={ 18 | 'dataset': args.dataset, 19 | 'batch_size': args.batch_size, 20 | 'n_epochs': args.n_epochs, 21 | 'lr': args.lr, 22 | 'decay': args.decay, 23 | 'input_size': args.input_size, 24 | 'in_channels': args.in_channels, 25 | 'out_channels': args.out_channels, 26 | 'sample_and_save_freq': args.sample_and_save_freq 27 | }, 28 | name = 'CycleGAN_{}'.format(args.dataset)) 29 | 30 | train_dataloader_A = get_horse2zebra_dataloader(data_raw_dir, args.dataset, args.batch_size, True, 'A', args.input_size) 31 | train_dataloader_B = get_horse2zebra_dataloader(data_raw_dir, args.dataset, args.batch_size, True, 'B', args.input_size) 32 | test_dataloader_A = get_horse2zebra_dataloader(data_raw_dir, args.dataset, args.batch_size, False, 'A', args.input_size) 33 | test_dataloader_B = get_horse2zebra_dataloader(data_raw_dir, args.dataset, args.batch_size, False, 'B', args.input_size) 34 | model = CycleGAN(args.in_channels, args.out_channels, args) 35 | model.train_model(train_dataloader_A, train_dataloader_B, test_dataloader_A, test_dataloader_B) 36 | 37 | elif args.test: 38 | test_dataloader_A = get_horse2zebra_dataloader(data_raw_dir, args.dataset, args.batch_size, False, 'A', args.input_size) 39 | test_dataloader_B = get_horse2zebra_dataloader(data_raw_dir, args.dataset, args.batch_size, False, 'B', args.input_size) 40 | model_AB = Generator(args.in_channels, args.out_channels).to(device) 41 | model_BA = Generator(args.in_channels, args.out_channels).to(device) 42 | model_AB.load_state_dict(torch.load(args.checkpoint_A)) 43 | model_BA.load_state_dict(torch.load(args.checkpoint_B)) 44 | 45 | model_AB.sample(test_dataloader_A, device) 46 | model_BA.sample(test_dataloader_B, device) 47 | 48 | #test(args.checkpoint_A, args.checkpoint_B, test_dataloader_A, test_dataloader_B, args.in_channels, args.out_channels, device) 49 | 50 | else: 51 | raise Exception('Please specify either --train or --test') -------------------------------------------------------------------------------- /docs/LOADING_ENV_VARIABLES.md: -------------------------------------------------------------------------------- 1 | # Loading environment variables 2 | 3 | [We use `python-dotenv` to load environment variables][python-dotenv], as these are only loaded when 4 | inside the project folder. This can prevent accidental conflicts with identically named 5 | variables. Alternatively you can use [`direnv` to load environment variables][direnv] if 6 | you meet [certain conditions](#installing-direnv). 7 | 8 | ## Using `python-dotenv` 9 | 10 | To load the environment variables, first make sure you have 11 | python-dotenv install, and [make sure you have a `.secrets` file to store 12 | secrets and credentials](#storing-secrets-and-credentials). Then to load in the 13 | environment variables into a python script see instructions in `.env` file. 
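For reference, a minimal sketch of that loading step, mirroring the commented example in the `.env` file (`EXAMPLE_VARIABLE` is just a placeholder, not a variable the project defines):

```python
from os import getenv
from os.path import dirname, join

import dotenv

# Resolve the project root and load `.env` (and, if present, `.secrets`),
# overriding any identically named system variables
project_dir = dirname(dirname(dirname(__file__)))
dotenv.load_dotenv(join(project_dir, ".env"), override=True)
dotenv.load_dotenv(join(project_dir, ".secrets"), override=True)

# Read a variable defined in `.env`
EXAMPLE_VARIABLE = getenv("EXAMPLE_VARIABLE")
```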
14 |  15 | ## Using `direnv` 16 |  17 | To load the environment variables, first [follow the `direnv` installation 18 | instructions](#installing-direnv), and [make sure you have a `.secrets` file to store 19 | secrets and credentials](#storing-secrets-and-credentials). Then: 20 |  21 | 1. Open your terminal; 22 | 2. Install `direnv`. See instructions below. 23 | 3. Navigate to the project folder; and 24 | - You should see the following message: 25 | ```shell 26 | direnv: error .envrc is blocked. Run `direnv allow` to approve its content. 27 | ``` 28 | 4. Allow `direnv`. 29 | ```shell 30 | direnv allow 31 | ``` 32 |  33 | You only need to do this once, and then again each time `.envrc` or `.secrets` is modified. 34 |  35 | ### Installing `direnv` 36 |  37 | 1. Open your terminal; 38 | 2. Install [`direnv`](https://direnv.net/docs/installation.html); 39 | 3. Add the shell hooks to your `.bash_profile`; 40 | ```shell 41 | echo 'eval "$(direnv hook bash)"' >> ~/.bash_profile 42 | ``` 43 | 4. Check that the shell hooks have been added correctly; and 44 | ```shell 45 | cat ~/.bash_profile 46 | ``` 47 | - This should display `eval "$(direnv hook bash)"` 48 | 5. Restart your terminal. 49 |  50 | ## Storing secrets and credentials 51 |  52 | Secrets and credentials must be stored in the `.secrets` file. This file is not 53 | version-controlled, so no secrets should be committed to GitHub. 54 |  55 | Open this new `.secrets` file using your preferred text editor, and add any secrets as 56 | environment variables. For example, to add a JSON credentials file: 57 |  58 | ```shell 59 | APPLICATION_CREDENTIALS="path/to/credentials.json" 60 | ``` 61 |  62 | Once complete, make sure the `.envrc` file has the following line uncommented: 63 |  64 | ```shell 65 | source_env ".secrets" 66 | ``` 67 |  68 | This ensures [`direnv`][direnv] loads the `.secrets` file using `.envrc` without 69 | version-controlling `.secrets`.
70 |  71 | [direnv]: https://direnv.net/ 72 | [python-dotenv]: https://saurabh-kumar.com/python-dotenv/ 73 | -------------------------------------------------------------------------------- /src/generativezoo/WGAN.py: -------------------------------------------------------------------------------- 1 | from models.GAN.WGAN import * 2 | import torch 3 | from data.Dataloaders import * 4 | import wandb 5 | import os 6 | from torchvision import transforms 7 | from torch.utils.data import DataLoader, Dataset 8 | from utils.util import parse_args_WassersteinGAN 9 |  10 | if __name__ == '__main__': 11 |  12 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | args = parse_args_WassersteinGAN() 14 |  15 | size = None 16 |  17 | if args.train: 18 | if not args.no_wandb: 19 | wandb.init(project="WGAN", 20 | config={ 21 | "dataset": args.dataset, 22 | "batch_size": args.batch_size, 23 | "n_epochs": args.n_epochs, 24 | "latent_dim": args.latent_dim, 25 | "d": args.d, 26 | "lrg": args.lrg, 27 | "lrd": args.lrd, 28 | "beta1": args.beta1, 29 | "beta2": args.beta2, 30 | "n_critic": args.n_critic, 31 | "gp_weight": args.gp_weight 32 | }, 33 | name=f"WGAN_{args.dataset}") 34 |  35 | train_loader, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=False, size=size, num_workers=args.num_workers) 36 | model = WGAN(args=args, imgSize=input_size, channels=channels) 37 | model.train_model(train_loader) 38 | wandb.finish() 39 |  40 | elif args.sample: 41 | _, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=1, normalize=False, size=size) 42 | model = Generator(latent_dim=args.latent_dim, channels=channels, d=args.d, imgSize=input_size).to(device) 43 | model.load_state_dict(torch.load(args.checkpoint)) 44 | model.sample(n_samples=args.n_samples, device=device) 45 |  46 | elif args.outlier_detection: 47 |  48 | in_loader, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=False, size=size, mode='val') 49 | out_loader, _, _ = pick_dataset(dataset_name=args.out_dataset, batch_size=args.batch_size, normalize=False, size=input_size, mode='val') 50 |  51 | model = WGAN(args=args, imgSize=input_size, channels=channels)  # build the full WGAN (generator + critic) as in the training branch 52 | model.D.load_state_dict(torch.load(args.discriminator_checkpoint)) 53 | model.outlier_detection(in_loader, out_loader, display=True) -------------------------------------------------------------------------------- /src/generativezoo/NCSNv2.py: -------------------------------------------------------------------------------- 1 | from models.SM.NCSNv2 import * 2 | from utils.util import parse_args_NCSNv2 3 | import torch 4 | from data.Dataloaders import * 5 | import wandb 6 |  7 |  8 | if __name__ == '__main__': 9 |  10 | device = "cuda" if torch.cuda.is_available() else "cpu" 11 |  12 | args = parse_args_NCSNv2() 13 |  14 | size = None 15 |  16 | if args.train: 17 | train_loader, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=False, size=size, num_workers=args.num_workers) 18 | if not args.no_wandb: 19 | wandb.init(project="NCSNv2", 20 |  21 | config = { 22 | "dataset": args.dataset, 23 | "batch_size": args.batch_size, 24 | "n_steps": args.n_steps, 25 | "lr":
args.lr, 26 | "n_epochs": args.n_epochs, 27 | "beta1": args.beta1, 28 | "beta2": args.beta2, 29 | "weight_decay": args.weight_decay, 30 | "nf": args.nf, 31 | "snr": args.snr, 32 | "probability_flow": args.probability_flow, 33 | "predictor": args.predictor, 34 | "corrector": args.corrector, 35 | "noise_removal": args.noise_removal, 36 | "sigma_max": args.sigma_max, 37 | "sigma_min": args.sigma_min, 38 | "num_scales": args.num_scales, 39 | "normalization": args.normalization, 40 | "continuous": args.continuous, 41 | "reduce_mean": args.reduce_mean, 42 | "likelihood_weighting": args.likelihood_weighting, 43 | "act": args.act, 44 | }, 45 | 46 | name=f"NCSNv2_{args.dataset}" 47 | 48 | ) 49 | 50 | model = NCSNv2(input_size, channels, args) 51 | model.train_model(train_loader, args) 52 | wandb.finish() 53 | 54 | elif args.sample: 55 | _, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=False, size=size) 56 | model = NCSNv2(input_size, channels, args) 57 | model.load_checkpoints(args.checkpoint) 58 | model.sample(args, False) 59 | 60 | else: 61 | raise ValueError("Invalid mode, choose either train, sample or outlier_detection.") -------------------------------------------------------------------------------- /src/generativezoo/GLOW.py: -------------------------------------------------------------------------------- 1 | from models.NF.Glow import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_Glow 4 | import wandb 5 | 6 | 7 | if __name__ == '__main__': 8 | 9 | args = parse_args_Glow() 10 | normalize = False 11 | 12 | size = None 13 | 14 | if args.train: 15 | if not args.no_wandb: 16 | wandb.init(project='GLOW', 17 | config={ 18 | "batch_size": args.batch_size, 19 | "lr": args.lr, 20 | "n_epochs": args.n_epochs, 21 | "dataset": args.dataset, 22 | "hidden_channels": args.hidden_channels, 23 | "K": args.K, 24 | "L": args.L, 25 | "actnorm_scale": args.actnorm_scale, 26 | "flow_permutation": args.flow_permutation, 27 | "flow_coupling": args.flow_coupling, 28 | "LU_decomposed": args.LU_decomposed, 29 | "learn_top": args.learn_top, 30 | "y_condition": args.y_condition, 31 | "num_classes": args.num_classes, 32 | "n_bits": args.n_bits, 33 | }, 34 | 35 | name = 'GLOW_{}'.format(args.dataset)) 36 | 37 | train_loader, input_shape, channels = pick_dataset(args.dataset, batch_size=args.batch_size, normalize=normalize, size=size, num_workers=args.num_workers) 38 | model = Glow(image_shape = (input_shape,input_shape,channels), hidden_channels = args.hidden_channels, args=args) 39 | model.train_model(train_loader, args) 40 | 41 | elif args.sample: 42 | _, input_shape, channels = pick_dataset(args.dataset, batch_size=args.batch_size, normalize=normalize, size=size, num_workers=0) 43 | model = Glow(image_shape = (input_shape,input_shape,channels), hidden_channels = args.hidden_channels, args=args) 44 | model.load_checkpoint(args) 45 | model.sample(train=False) 46 | 47 | elif args.outlier_detection: 48 | in_loader, input_shape, channels = pick_dataset(args.dataset, batch_size=args.batch_size, normalize=normalize, size=size, num_workers=0, mode='val') 49 | out_loader, _, _ = pick_dataset(args.out_dataset, batch_size=args.batch_size, normalize=normalize, size=input_shape, num_workers=0, mode='val') 50 | model = Glow(image_shape = (input_shape,input_shape,channels), hidden_channels = args.hidden_channels, args=args) 51 | model.load_checkpoint(args) 52 | model.outlier_detection(in_loader, out_loader) 53 | 54 | else: 55 | raise 
ValueError("Invalid mode. Please specify train or sample") 56 | -------------------------------------------------------------------------------- /docs/Text2Img_ControlNet.md: -------------------------------------------------------------------------------- 1 | # ControlNet 2 | 3 | ## Prepare the Dataset 4 | 5 | If you want to train the model in a custom dataset, three elements must be provided: the `images`, the `conditioning images` and the `prompts`. The images must be provided in a folder organized as follows: 6 | 7 | ```bash 8 | ├── dataset 9 | │ ├── images 10 | │ │ ├── 11 | │ │ ├── 12 | │ │ ├── ... 13 | │ ├── conditioning_images 14 | │ │ ├── 15 | │ │ ├── 16 | │ │ ├── ... 17 | │ ├── train.jsonl 18 | ``` 19 | 20 | The file `train.jsonl` contains the prompts associated to each image and conditioning image and should be structured like the following example: 21 | 22 | ```json 23 | {"text": "a drawing of a green pokemon with red eyes", "image": "./../../../../data/processed/pokemons/images/0000.png", "conditioning_image": "./../../../../data/processed/pokemons/conditioning_images/0000_mask.png"} 24 | {"text": "a green and yellow toy with a red nose", "image": "./../../../../data/processed/pokemons/images/0001.png", "conditioning_image": "./../../../../data/processed/pokemons/conditioning_images/0001_mask.png"} 25 | ... 26 | ``` 27 | 28 | ## Accelerator Config 29 | 30 | Please configure the accelerator to match your system requirements by running: 31 | 32 | ```bash 33 | accelerate config 34 | ``` 35 | 36 | ## Train the model 37 | 38 | A [`script file`](./../scripts/Text2Img_Controlnet.sh) is provided with the commands required to train the model on a custom dataset. Several parameters should be configured, mainly: 39 | 40 | ```sh 41 | export OUTPUT_DIR="./../../../../models/Text2Img_Controlnet/pokemons" 42 | export DATASET_PATH="./../../../../data/processed/pokemons" 43 | ``` 44 | 45 | Several parameters in the training command can be tuned by the user, particularly the `--validation_prompt` and `--validation_image` which should reflect the use case of this training. 
46 | 47 | ```sh 48 | accelerate launch --mixed_precision="fp16" Text2Img_Controlnet.py \ 49 | --pretrained_model_name_or_path=$MODEL_DIR \ 50 | --output_dir=$OUTPUT_DIR \ 51 | --train_data_dir=$DATASET_PATH \ 52 | --resolution=512 \ 53 | --learning_rate=1e-5 \ 54 | --max_train_steps=10000 \ 55 | --checkpointing_steps=20 \ 56 | --validation_image "./../../../../data/processed/pokemons/conditioning_images/0003_mask.png" \ 57 | --validation_prompt "red circle pokemon with white dots" \ 58 | --train_batch_size=1 \ 59 | --report_to=wandb \ 60 | --image_column="image" \ 61 | --caption_column="text" \ 62 | --conditioning_image_column='conditioning_image' \ 63 | --gradient_accumulation_steps=1 \ 64 | ``` 65 | 66 | For GPUs with less VRAM, you might consider using some of the following options to reduce memory usage: 67 | 68 | ```sh 69 | --use_8bit_adam \ 70 | --gradient_checkpointing \ 71 | --set_grads_to_none \ 72 | ``` 73 | 74 | ## Inference 75 | 76 | A Python script is provided to use the trained ControlNet adapters: 77 | 78 | ```bash 79 | python Text2Img_ControlNet.py --cnet_model_path ./../../models/Text2Img_ControlNet/pokemons/checkpoint-200/controlnet --cond_image_path ./../../data/processed/pokemons/conditioning_images/0003_mask.png 80 | ``` -------------------------------------------------------------------------------- /src/generativezoo/AdvVAE.py: -------------------------------------------------------------------------------- 1 | from models.GAN.AdversarialVAE import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_AdversarialVAE 4 | import torch 5 | import wandb 6 | 7 | if __name__ == '__main__': 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | args = parse_args_AdversarialVAE() 12 | 13 | size = args.size 14 | 15 | if args.train: 16 | if not args.no_wandb: 17 | wandb.init(project='AdversarialVAE', 18 | config={ 19 | 'dataset': args.dataset, 20 | 'batch_size': args.batch_size, 21 | 'n_epochs': args.n_epochs, 22 | 'latent_dim': args.latent_dim, 23 | 'hidden_dims': args.hidden_dims, 24 | 'lr': args.lr, 25 | 'gen_weight': args.gen_weight, 26 | 'recon_weight': args.recon_weight, 27 | 'sample_and_save_frequency': args.sample_and_save_frequency, 28 | 'kld_weight': args.kld_weight, 29 | }, 30 | name = 'AdversarialVAE_{}'.format(args.dataset)) 31 | 32 | train_loader, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=True, num_workers=args.num_workers, mode='train', size=size, n_patches=args.patches) 33 | model = AdversarialVAE(input_shape = input_size, input_channels=channels, args=args) 34 | model.train_model(train_loader) 35 | 36 | elif args.test: 37 | test_loader, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=True, mode='val', size=size) 38 | model = AdversarialVAE(input_shape = input_size, input_channels=channels, args=args) 39 | model.load_state_dict(torch.load(args.checkpoint)) 40 | model.create_validation_grid(test_loader) 41 | elif args.sample: 42 | _, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=True, mode='val', size=size) 43 | model = AdversarialVAE(input_shape = input_size, input_channels=channels, args=args) 44 | model.load_state_dict(torch.load(args.checkpoint)) 45 | model.create_grid() 46 | elif args.outlier_detection: 47 | in_loader, input_size, channels = pick_dataset(dataset_name=args.dataset, batch_size=args.batch_size, normalize=True, mode='val', 
size=size) 48 | out_loader, _, _ = pick_dataset(dataset_name=args.out_dataset, batch_size=args.batch_size, normalize=True, mode='val', size=input_size) 49 | model = AdversarialVAE(input_shape = input_size, input_channels=channels, args=args) 50 | if args.checkpoint is not None: 51 | model.vae.load_state_dict(torch.load(args.checkpoint)) 52 | if args.discriminator_checkpoint is not None: 53 | model.discriminator.load_state_dict(torch.load(args.discriminator_checkpoint)) 54 | model.eval() 55 | model.outlier_detection(in_loader, out_loader) 56 | else: 57 | Exception("Invalid mode. Set --train, --test or --sample") -------------------------------------------------------------------------------- /docs/InstructPix2Pix.md: -------------------------------------------------------------------------------- 1 | # InstructPix2Pix 2 | 3 | ## Prepare the Dataset 4 | 5 | If you want to train the model in a custom dataset, three elements must be provided: the `original images`, the `edited images` and the `edit prompts`. The images must be provided in a folder organized as follows: 6 | 7 | ```bash 8 | ├── dataset 9 | │ ├── original_images 10 | │ │ ├── 11 | │ │ ├── 12 | │ │ ├── ... 13 | │ ├── edited_images 14 | │ │ ├── 15 | │ │ ├── 16 | │ │ ├── ... 17 | │ ├── train.jsonl 18 | ``` 19 | 20 | The file `train.jsonl` contains the edit prompts associated to each original image and edited image and should be structured like the following example: 21 | 22 | ```json 23 | {"edit_prompt": "a drawing of a green pokemon with red eyes", "original_image": "./../../../../data/processed/pokemons/images/0000.png", "edited_image": "./../../../../data/processed/pokemons/conditioning_images/0000_mask.png"} 24 | {"edit_prompt": "a green and yellow toy with a red nose", "original_image": "./../../../../data/processed/pokemons/images/0001.png", "edited_image": "./../../../../data/processed/pokemons/conditioning_images/0001_mask.png"} 25 | ... 26 | ``` 27 | 28 | ## Accelerator Config 29 | 30 | Please configure the accelerator to match your system requirements by running: 31 | 32 | ```bash 33 | accelerate config 34 | ``` 35 | 36 | ## Train the model 37 | 38 | A [`script file`](./../scripts/InstructPix2Pix.sh) is provided with the commands required to train the model on a custom dataset. Several parameters should be configured, mainly: 39 | 40 | ```sh 41 | export OUTPUT_DIR="./../../../../models/InstructPix2Pix/pokemons" 42 | export DATASET_PATH="./../../../../data/processed/pokemons" 43 | ``` 44 | 45 | Several parameters in the training command can be tuned by the user, particularly the `--validation_prompt` and `--validation_image` which should reflect the use case of this training. 
46 |  47 | ```sh 48 | accelerate launch --mixed_precision="fp16" train_InstructPix2Pix.py \ 49 | --pretrained_model_name_or_path=$MODEL_NAME \ 50 | --train_data_dir=$DATASET_PATH \ 51 | --output_dir=$OUTPUT_DIR \ 52 | --resolution=512 --random_flip \ 53 | --train_batch_size=4 --gradient_accumulation_steps=1 --gradient_checkpointing \ 54 | --max_train_steps=15000 \ 55 | --checkpointing_steps=5000 \ 56 | --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ 57 | --conditioning_dropout_prob=0.05 \ 58 | --validation_image "./../../../../data/processed/pokemons/conditioning_images/0003_mask.png" \ 59 | --validation_prompt "red circle pokemon with white dots" \ 60 | --seed=42 \ 61 | --report_to=wandb \ 62 | --original_image_column="original_image" \ 63 | --edited_image_column="edited_image" \ 64 | --edit_prompt_column="edit_prompt" 65 | ``` 66 |  67 | For GPUs with less VRAM, you might consider using some of the following options to reduce memory usage: 68 |  69 | ```sh 70 | --use_8bit_adam \ 71 | --gradient_checkpointing \ 72 | --set_grads_to_none \ 73 | ``` 74 |  75 | ## Inference 76 |  77 | A Python script is provided to use the trained InstructPix2Pix model: 78 |  79 | ```bash 80 | python InstructPix2Pix.py --pix2pix_model ./../../models/InstructPix2Pix/checkpoint-20/unet --image_path ./../../data/processed/pokemons/conditioning_images/0003_mask.png 81 | ``` -------------------------------------------------------------------------------- /docs/HierarchicalVAE.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Variational Autoencoders (Hierarchical VAEs) 2 |  3 | NVAE introduced a deep hierarchical VAE designed for image generation, built on depth-wise separable convolutions and batch normalisation. NVAE has a residual parameterization for normal distributions and uses spectral regularisation to stabilise training.
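As a rough illustration of the residual parameterisation (a simplified sketch, not the layers used in `HVAE.py`), the approximate posterior at each level is expressed as a correction on top of the prior's parameters:

```python
import torch

def residual_posterior_sample(mu_p, log_sig_p, d_mu, d_log_sig):
    """Simplified sketch: q(z|x) = N(mu_p + d_mu, sigma_p * d_sigma), i.e. the
    encoder only predicts a residual on top of the prior's mean and log-std."""
    mu_q = mu_p + d_mu
    log_sig_q = log_sig_p + d_log_sig
    # Reparameterised sample from the residual posterior
    return mu_q + torch.exp(log_sig_q) * torch.randn_like(mu_q)
```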
4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |--------------------------|-----------------------------------------|-----------------|------------------------------------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample model | `False` | | 11 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 12 | | `--no_wandb` | Disable Wandb | `False` | | 13 | | `--batch_size` | Batch size | `256` | | 14 | | `--n_epochs` | Number of epochs | `100` | | 15 | | `--lr` | Learning rate | `0.01` | | 16 | | `--latent_dim` | Latent dimension | `512` | | 17 | | `--checkpoint` | Checkpoint path | `None` | | 18 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 19 | | `--num_workers` | Number of workers for Dataloader | `0` | | 20 | 21 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 22 | 23 | python HVAE.py --help 24 | 25 | ## Training 26 | 27 | HVAE can be trained similarly to other models in the Zoo: 28 | 29 | python HVAE.py --train --dataset svhn 30 | 31 | ## Sampling 32 | 33 | For sampling you must provide the HVAE checkpoint: 34 | 35 | python HVAE.py --sample --dataset svhn --checkpoint ./../../models/HierarchicalVAE/HVAE_svhn.pt -------------------------------------------------------------------------------- /docs/CycleGAN.md: -------------------------------------------------------------------------------- 1 | # Cycle Generative Adversarial Network (CycleGAN) 2 | 3 | CycleGAN is a type of GAN designed for unsupervised image-to-image translation. Unlike traditional methods that require paired data for training, CycleGAN learns to translate images from one domain to another in the absence of paired examples. It accomplishes this by simultaneously training two generators and two discriminators in an adversarial manner. 4 | 5 | ## Parameters 6 | 7 | | Parameter | Description | Default | Choices | 8 | |-----------------------|-------------------------------------------------|---------|---------| 9 | | `--train` | train model | `False` | | 10 | | `--test` | test model | `False` | | 11 | | `--batch_size` | batch size | `1` | | 12 | | `--n_epochs` | number of epochs | `200` | | 13 | | `--lr` | learning rate | `0.0002`| | 14 | | `--decay` | epoch to start linearly decaying the learning rate to 0 | `100` | | 15 | | `--sample_and_save_freq` | sample and save frequency | `5` | | 16 | | `--dataset` | dataset name | `'horse2zebra'` | `'horse2zebra'` | 17 | | `--checkpoint_A` | checkpoint A path | `None` | | 18 | | `--checkpoint_B` | checkpoint B path | `None` | | 19 | | `--input_size` | input size | `128` | | 20 | | `--in_channels` | in channels | `3` | | 21 | | `--out_channels` | out channels | `3` | | 22 | | `--no_wandb` | Disable Wandb | `False` | | 23 | 24 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 25 | 26 | python CycGAN.py --help 27 | 28 | 29 | ## Training 30 | 31 | CycleGAN introduces cycle consistency loss, which enforces the translated images to revert back to their original form when translated back to the original domain. 
This concept enables CycleGAN to learn mappings between domains in a self-supervised manner, making it particularly effective for tasks such as style transfer, object transfiguration, and domain adaptation without paired data. It couples this loss with an identity loss, a reconstruction loss and the adversarial loss to train the generators. 32 |  33 |     python CycGAN.py --train --dataset horse2zebra 34 |  35 | ## Test 36 |  37 | Unfortunately, it is not possible to sample from a CycleGAN; we can only perform image-to-image translation. Therefore, if we input an image of domain A to the Generator that learnt how to map A->B, we can obtain a translated image. 38 |  39 |     python CycGAN.py --test --dataset horse2zebra --checkpoint_A ./../../models/CycleGAN/CycGAN_horse2zebra_AB.pt --checkpoint_B ./../../models/CycleGAN/CycGAN_horse2zebra_BA.pt -------------------------------------------------------------------------------- /src/generativezoo/PresGAN.py: -------------------------------------------------------------------------------- 1 | from models.GAN.PrescribedGAN import * 2 | from data.Dataloaders import * 3 | from utils.util import parse_args_PresGAN 4 | import torch 5 | import wandb 6 |  7 | if __name__ == '__main__': 8 | args = parse_args_PresGAN() 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 |  11 | size = None 12 |  13 | if args.train: 14 | train_dataloader, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=args.batch_size, normalize = True, size=size, num_workers=args.num_workers) 15 | if not args.no_wandb: 16 | wandb.init(project="PresGAN", 17 |  18 | config = { 19 | "dataset": args.dataset, 20 | "batch_size": args.batch_size, 21 | "nz": args.nz, 22 | "ngf": args.ngf, 23 | "ndf": args.ndf, 24 | "lrD": args.lrD, 25 | "lrG": args.lrG, 26 | "beta1": args.beta1, 27 | "n_epochs": args.n_epochs, 28 | "sigma_lr": args.sigma_lr, 29 | "num_gen_images": args.num_gen_images, 30 | "restrict_sigma": args.restrict_sigma, 31 | "sigma_min": args.sigma_min, 32 | "sigma_max": args.sigma_max, 33 | "stepsize_num": args.stepsize_num, 34 | "lambda_": args.lambda_, 35 | "burn_in": args.burn_in, 36 | "num_samples_posterior": args.num_samples_posterior, 37 | "leapfrog_steps": args.leapfrog_steps, 38 | "hmc_learning_rate": args.hmc_learning_rate, 39 | "hmc_opt_accept": args.hmc_opt_accept, 40 | "flag_adapt": args.flag_adapt 41 | }, 42 |  43 | name=f"PresGAN_{args.dataset}" 44 |  45 | ) 46 | model = PresGAN(imgSize=input_size, channels=channels, args=args) 47 | model.train_model(train_dataloader) 48 | wandb.finish() 49 |  50 | elif args.sample: 51 | _, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=args.batch_size, normalize = True, size = size) 52 | model = PresGAN(imgSize=input_size, channels=channels, args=args) 53 | model.load_checkpoints(generator_checkpoint=args.checkpoint, discriminator_checkpoint=args.discriminator_checkpoint, sigma_checkpoint=args.sigma_checkpoint) 54 | model.sample(num_samples=args.num_gen_images) 55 |  56 | elif args.outlier_detection: 57 | in_loader, input_size, channels = pick_dataset(dataset_name = args.dataset, batch_size=args.batch_size, normalize = True, size = size, mode="val") 58 | out_loader, _, _ = pick_dataset(dataset_name = args.out_dataset, batch_size=args.batch_size, normalize = True, size = input_size, mode="val") 59 | model = PresGAN(imgSize=input_size, channels=channels, args=args) 60 | model.load_checkpoints(generator_checkpoint=args.checkpoint,
discriminator_checkpoint=args.discriminator_checkpoint, sigma_checkpoint=args.sigma_checkpoint) 61 | model.outlier_detection(in_loader, out_loader) 62 | 63 | else: 64 | raise Exception("Invalid mode. Set the --train or --sample flag") -------------------------------------------------------------------------------- /docs/SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Guidelines 4 | 5 | - Never store credentials as code/config. See [Loading environment variables](loading_environment_variables.md) 6 | - Passwords in your publicly available code can easily get into the wrong hands, which is why it's best to 7 | avoid putting credentials into your repository in the first place 8 | - [Keycloak](https://www.keycloak.org/) is used for identity and access management in modern applications and services. 9 | - [Git-secrets](https://github.com/awslabs/git-secrets) statically analyzes your commits via a pre-commit git kook to 10 | ensure you're not pushing any passwords or sensitive information into your Bitbucket repository. 11 | Commits are rejected if the tool matches any of the configured regular expression patterns that indicate that sensitive 12 | information has been stored improperly. 13 | - Remove sensitive data from your files and Bitbucket history. 14 | - If you commit sensitive data, such as a password or SSH key into a git repository, you can remove it from the history. 15 | To entirely remove unwanted files from a repository's history you can use either the `git filter-repo` tool or the 16 | [BFG Repo-Cleaner](https://rtyley.github.io/bfg-repo-cleaner/) open source tool. 17 | - Access Control 18 | - Never let Bitbucket users share accounts/passwords 19 | - Make sure you diligently revoke access from Bitbucket users who are no longer working with you 20 | - Report vulnerabilities 21 | - Run [Safety](https://github.com/pyupio/safety) and [Bandit](https://bandit.readthedocs.io/en/latest/) to find new 22 | vulnerabilities. [Trivy](https://github.com/aquasecurity/trivy) scans vulnerabilities in container images. 23 | 24 | ## Supported Versions 25 | 26 | This project is under active development, and we do our best to support the latest versions. 27 | 28 | | Version | Supported | 29 | | ------- | ------------------ | 30 | | latest | 1.3.0 | 31 | 32 | ## Reporting a Vulnerability or Security Issues 33 | 34 | > Do not open issues that might have security implications! 35 | > It is critical that security related issues are reported privately so we have time to address them before they become 36 | > public knowledge. 37 | 38 | Vulnerabilities can be reported by emailing core members: 39 | 40 | - None \[none@none.pt\](mailto:none@none.pt) 41 | 42 | Please include the requested information listed below (as much as you can provide) to help us better understand the 43 | nature and scope of the possible issue: 44 | 45 | - Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 46 | - Full paths of source file(s) related to the manifestation of the issue 47 | - The location of the affected source code (tag/branch/commit or direct URL) 48 | - Any special configuration required to reproduce the issue 49 | - Environment (e.g. Linux / Windows / macOS) 50 | - Step-by-step instructions to reproduce the issue 51 | - Proof-of-concept or exploit code (if possible) 52 | - Impact of the issue, including how an attacker might exploit the issue 53 | 54 | This information will help us triage your report more quickly. 
55 | 56 | ## Preferred Languages 57 | 58 | We prefer all communications to be in English. 59 | 60 | ## References 61 | 62 | - [5 tips to keep your code secure](https://bitbucket.org/blog/keep-your-code-secure) 63 | - [Removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository) 64 | -------------------------------------------------------------------------------- /docs/PixelCNN.md: -------------------------------------------------------------------------------- 1 | # Pixel Convolutional Neural Network (PixelCNN) 2 | 3 | 4 | ## Parameters 5 | 6 | | Argument | Description | Default | Choices | 7 | |--------------------------|-----------------------------------------|-----------------|------------------------------------------------------------------------------------------| 8 | | `--train` | Train model | `False` | | 9 | | `--sample` | Sample from model | `False` | | 10 | | `--outlier_detection` | Outlier detection | `False` | | 11 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 12 | | `--no_wandb` | Disable Wandb | `False` | | 13 | | `--out_dataset` | Outlier dataset name | `fashionmnist`| `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`,`imagenet` | 14 | | `--batch_size` | Batch size | `128` | | 15 | | `--n_epochs` | Number of epochs | `100` | | 16 | | `--lr` | Learning rate | `1e-3` | | 17 | | `--gamma` | Gamma for the lr scheduler | `0.99` | | 18 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 19 | | `--hidden_channels` | Number of channels for the convolutional layers | `64` | | 20 | | `--checkpoint` | Checkpoint path | `None` | | 21 | | `--num_workers` | Number of workers for Dataloader | `0` | | 22 | 23 | 24 | ## Training 25 | 26 | The PixelCNN can be trained with: 27 | 28 | python P-CNN.py --train 29 | 30 | ## Sampling 31 | 32 | For sampling you must provide the checkpoint: 33 | 34 | python P-CNN.py --sample --checkpoint ./../../models/PixelCNN/PixelCNN_mnist.pt 35 | 36 | ## Outlier Detection 37 | 38 | Outlier Detection is performed by using the NLL scores generated by the model: 39 | 40 | python P-CNN.py --outlier_detection --checkpoint ./../../models/PixelCNN/PixelCNN_mnist.pt -------------------------------------------------------------------------------- /docs/ConditionalGAN.md: -------------------------------------------------------------------------------- 1 | # Conditional Generative Adversarial Network (cGAN) 2 | 3 | A Conditional Generative Adversarial Network (cGAN) is an extension of the traditional GAN framework where additional conditioning information, typically in the form of class labels or embeddings, is provided to both the generator and the discriminator. This allows for the generation of samples conditioned on specific attributes or classes, enhancing control over the generated outputs. 
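As an illustrative sketch of this conditioning (not the exact architecture implemented in this repository), the class label is typically embedded and concatenated with the generator's latent vector, and analogously with the discriminator's input:

```python
import torch
import torch.nn as nn

class ToyConditionalGenerator(nn.Module):
    """Minimal conditional generator: the label embedding is concatenated
    with the noise vector so every sample is tied to a class."""
    def __init__(self, latent_dim=100, n_classes=10, img_size=32, channels=1):
        super().__init__()
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.net = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, channels * img_size * img_size),
            nn.Tanh(),
        )
        self.out_shape = (channels, img_size, img_size)

    def forward(self, z, labels):
        # Condition the generation on the class label
        h = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(h).view(-1, *self.out_shape)
```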
4 |  5 | ## Parameters 6 |  7 | | Parameter | Description | Default | Choices | 8 | |-------------------|------------------------------------|---------|--------------------------------------------------------------| 9 | | `--train` | train model | `False` | | 10 | | `--sample` | sample from model | `False` | | 11 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 12 | | `--no_wandb` | Disable Wandb | `False` | | 13 | | `--batch_size` | batch size | `128` | | 14 | | `--n_epochs` | number of epochs | `100` | | 15 | | `--lr` | learning rate | `0.0002`| | 16 | | `--beta1` | beta1 | `0.5` | | 17 | | `--beta2` | beta2 | `0.999` | | 18 | | `--latent_dim` | latent dimension | `100` | | 19 | | `--n_classes` | number of classes | `10` | | 20 | | `--img_size` | image size | `32` | | 21 | | `--channels` | channels | `1` | | 22 | | `--sample_and_save_freq` | sample interval | `5` | | 23 | | `--checkpoint` | checkpoint path | `None` | | 24 | | `--n_samples` | number of samples | `9` | | 25 | | `--d` | number of initial filters | `128` | | 26 | | `--num_workers` | Number of workers for Dataloader | `0` | | 27 |  28 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 29 |  30 |     python CondGAN.py --help 31 |  32 | ## Training 33 |  34 | The training process is similar to the one mentioned in [`VanillaGAN.md`](VanillaGAN.md), but with the inclusion of the aforementioned embeddings. 35 |  36 |     python CondGAN.py --train --dataset mnist --n_classes 10 37 |  38 | ## Sampling 39 |  40 | The sampling process is also similar, but it additionally requires the class-related embedding rather than only a noisy latent sample: 41 |  42 |     python CondGAN.py --sample --dataset mnist --n_classes 10 --checkpoint ./../../models/ConditionalGAN/CondGAN_mnist.pt -------------------------------------------------------------------------------- /docs/ConditionalVAE.md: -------------------------------------------------------------------------------- 1 | # Conditional Variational Autoencoder (Conditional VAE) 2 |  3 | The Conditional Variational Autoencoder (Conditional VAE) is an extension of the Vanilla VAE that incorporates additional conditional information during the training and generation process, in this case using a class label.
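A minimal sketch of how such class conditioning can be wired in (illustrative only; the actual layers live in `models/VAE/ConditionalVAE.py` and may differ):

```python
import torch
import torch.nn.functional as F

def condition_on_label(x, z, labels, n_classes=10):
    """Append a one-hot class label to the encoder input (as extra channels)
    and to the latent code fed to the decoder."""
    y = F.one_hot(labels, n_classes).float()
    y_map = y[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
    encoder_in = torch.cat([x, y_map], dim=1)   # (B, C + n_classes, H, W)
    decoder_in = torch.cat([z, y], dim=1)       # (B, latent_dim + n_classes)
    return encoder_in, decoder_in
```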
4 |  5 | ### Parameters 6 |  7 | | Parameter | Description | Default | Choices | 8 | |-----------------|---------------------------------------|---------|----------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample model | `False` | | 11 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 12 | | `--no_wandb` | Disable Wandb | `False` | | 13 | | `--batch_size` | Batch size | `128` | | 14 | | `--n_epochs` | Number of epochs | `100` | | 15 | | `--lr` | Learning rate | `0.0002`| | 16 | | `--latent_dim` | Latent dimension | `128` | | 17 | | `--hidden_dims` | Hidden dimensions | `None` | | 18 | | `--checkpoint` | Checkpoint path | `None` | | 19 | | `--num_samples` | Number of samples | `16` | | 20 | | `--n_classes` | Number of classes on dataset | `10` | | 21 | | `--sample_and_save_frequency`| sample and save frequency | `5` | | 22 | | `--loss_type` | Type of loss to evaluate reconstruction | `mse` | `mse`, `ssim` | 23 | | `--kld_weight` | KL-Divergence weight | `1e-4` | | 24 | | `--num_workers` | Number of workers for Dataloader | `0` | | 25 |  26 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 27 |  28 |     python CondVAE.py --help 29 |  30 | ## Training 31 |  32 | The training process for the Conditional VAE is similar to the one described in [`VanillaVAE.md`](VanillaVAE.md). Both models aim to maximize the evidence lower bound (ELBO) by minimizing the reconstruction loss and the KL divergence between the estimated latent distribution and the prior distribution. The reconstruction loss measures the difference between the generated output and the original input, while the KL divergence encourages the latent distribution to match the prior distribution. 33 |  34 | To train a model on the MNIST dataset, you can run the provided example script: 35 |  36 |     python CondVAE.py --train --dataset mnist --n_classes 10 37 |  38 | ## Sampling 39 |  40 | Sampling from the Conditional VAE is similar to the sampling process of a Vanilla VAE, but class information is added. 41 |  42 | 1. Sample a point from the latent space. This can be done by sampling from a prior distribution, typically a Gaussian distribution with mean 0 and variance 1. Pick a class and represent it in the required embedding format. 43 |  44 | 2. Pass the sampled point and the embedding through the decoder network to generate a new data point of the given class. 45 |  46 | You can sample from the model you trained on MNIST by running: 47 |  48 |     python CondVAE.py --sample --dataset mnist --checkpoint ./../../models/ConditionalVAE/CondVAE_mnist.pt 49 | -------------------------------------------------------------------------------- /docs/DiffusionAE.md: -------------------------------------------------------------------------------- 1 | # Diffusion Autoencoder (DiffAE) 2 |  3 | The Diffusion Autoencoder (DiffAE) is a model that learns to encode images into a latent space using an encoder and then utilizes this latent representation to guide the image generation process through a diffusion network. By jointly training the encoder and diffusion network, the diffusion autoencoder achieves effective latent space representation learning and image generation, facilitating tasks such as image reconstruction and image manipulation.
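A rough sketch of the joint objective (illustrative only; `scheduler.add_noise`, the `context` argument and the other names are placeholders for this sketch, not the interfaces used by `MONAI_DiffAE`):

```python
import torch
import torch.nn.functional as F

def diffae_step(encoder, denoiser, scheduler, x, num_timesteps=1000):
    """The encoder produces a semantic latent that conditions the diffusion
    UNet, which is trained with the usual noise-prediction loss."""
    z_sem = encoder(x)                                   # semantic latent code
    noise = torch.randn_like(x)
    t = torch.randint(0, num_timesteps, (x.size(0),), device=x.device)
    x_t = scheduler.add_noise(x, noise, t)               # forward diffusion
    noise_pred = denoiser(x_t, t, context=z_sem)         # conditioned on z_sem
    return F.mse_loss(noise_pred, noise)
```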
4 | 5 | ## Parameters 6 | 7 | | Parameter | Description | Default | Choices | 8 | |-----------------------|----------------------------------------|---------|--------------------------------------------------------------| 9 | | `--train` | train model | `False` | | 10 | | `--manipulate` | manipulate latents | `False` | | 11 | | `--batch_size` | batch size | `16` | | 12 | | `--n_epochs` | number of epochs | `100` | | 13 | | `--lr` | learning rate | `0.001` | | 14 | | `--timesteps` | number of timesteps | `1000` | | 15 | | `--sample_timesteps` | number of timesteps for sampling | `100` | | 16 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 17 | | `--no_wandb` | Disable Wandb | `False` | | 18 | | `--checkpoint` | checkpoint path | `None` | | 19 | | `--embedding_dim` | embedding dimension | `512` | | 20 | | `--model_channels` | model channels | `[64, 128, 256]` | | 21 | | `--attention_levels` | attention levels | `[False, True, True]` | | 22 | | `--num_res_blocks` | number of res blocks | `1` | | 23 | | `--sample_and_save_freq` | sample and save frequency | `10` | | 24 | | `--num_workers` | Number of workers for Dataloader | `0` | | 25 | 26 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 27 | 28 | python DAE.py --help 29 | 30 | ## Training 31 | 32 | During training, the diffusion autoencoder leverages the noise prediction capability of the diffusion network to simultaneously train both the encoder and diffusion network. This training process presents opportunities for future optimizations by modifying the training objective to include latent space representation losses. 33 | 34 | python DAE.py --train --dataset pneumoniamnist 35 | 36 | ## Manipulate Images 37 | 38 | While direct sampling from the model may not be feasible, an alternative approach involves training binary classifiers on the embeddings produced by the encoders, using class or feature labels. These classifiers can then be leveraged to manipulate the latent space, enabling the control of specific features within generated images. This technique allows for targeted manipulation of image features. 39 | 40 | python DAE.py --manipulate --dataset pneumoniamnist --checkpoint ./../../models/DiffusionAE/DiffAE_pneumoniamnist.pt -------------------------------------------------------------------------------- /docs/RealNVP.md: -------------------------------------------------------------------------------- 1 | # RealNVP 2 | 3 | This work implements the Real-valued Non-Volume Preserving (RealNVP) transformations, a set of powerful invertible and learnable transformations, resulting in an unsupervised learning algorithm with exact log-likelihood computation, exact sampling, exact inference of latent variables, and an interpretable latent space. 
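As an illustration of the core building block, here is a minimal toy sketch of an affine coupling layer on flat vectors, showing the exact log-determinant used for the likelihood and the exact inverse used for sampling; it is not the implementation in `RNVP.py`:

```python
import torch
import torch.nn as nn

class ToyAffineCoupling(nn.Module):
    """Minimal RealNVP-style affine coupling on a flat vector."""

    def __init__(self, dim):
        super().__init__()
        self.half = dim // 2
        self.net = nn.Sequential(nn.Linear(self.half, 64), nn.ReLU(),
                                 nn.Linear(64, 2 * (dim - self.half)))

    def forward(self, x):
        x1, x2 = x[:, :self.half], x[:, self.half:]
        log_s, t = self.net(x1).chunk(2, dim=1)
        log_s = torch.tanh(log_s)                  # keep the scale well-behaved
        y2 = x2 * log_s.exp() + t                  # invertible transform of one half
        log_det = log_s.sum(dim=1)                 # exact log|det J| for the likelihood
        return torch.cat([x1, y2], dim=1), log_det

    def inverse(self, y):
        y1, y2 = y[:, :self.half], y[:, self.half:]
        log_s, t = self.net(y1).chunk(2, dim=1)
        log_s = torch.tanh(log_s)
        x2 = (y2 - t) * (-log_s).exp()             # exact inverse enables sampling
        return torch.cat([y1, x2], dim=1)
```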
4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |-----------------------------|---------------------------------------------------|-------------|---------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample model | `False` | | 11 | | `--outlier_detection` | Outlier detection | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 15 | | `--batch_size` | Batch size | `128` | | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--lr` | Learning rate | `1e-3` | | 18 | | `--weight_decay` | Weight decay | `1e-5` | | 19 | | `--max_grad_norm` | Max grad norm | `100.0` | | 20 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 21 | | `--num_scales` | Number of scales | `2` | | 22 | | `--mid_channels` | Mid channels | `64` | | 23 | | `--num_blocks` | Number of blocks | `8` | | 24 | | `--checkpoint` | Checkpoint path | `None` | | 25 | | `--num_workers` | Number of workers for Dataloader | `0` | | 26 | 27 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 28 | 29 | python RNVP.py --help 30 | 31 | ## Training 32 | 33 | You can train this model with the following command: 34 | 35 | python RNVP.py --train --dataset octmnist 36 | 37 | ## Sampling 38 | 39 | To sample, please provide the checkpoint: 40 | 41 | python RNVP.py --sample --dataset octmnist --checkpoint ./../../models/RealNVP/RealNVP_octmnist.pt 42 | 43 | ## Outlier Detection 44 | 45 | Outlier Detection is performed by using the NLL scores generated by the model: 46 | 47 | python RNVP.py --outlier_detection --dataset octmnist --out_dataset mnist --checkpoint ./../../models/RealNVP/RealNVP_octmnist.pt -------------------------------------------------------------------------------- /docs/WassersteinGAN.md: -------------------------------------------------------------------------------- 1 | # Wasserstein GAN with Gradient Penalty (WGAN-GP) 2 | 3 | WGAN-GP presents an alternative to clipping weights in typical WGANs: penalize the norm of gradient of the critic with respect to its input. This method performs better than standard WGAN and enables stable training of a wide variety of GAN architectures. 
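For reference, below is a minimal sketch of the gradient penalty term, assuming a `critic` network and image-shaped inputs; it illustrates the idea rather than the exact code in `WGAN.py`:

```python
import torch

def gradient_penalty(critic, real, fake, gp_weight=10.0):
    """WGAN-GP penalty: push the critic's gradient norm towards 1 on
    random interpolations between real and generated samples."""
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True, retain_graph=True)[0]
    grad_norm = grads.flatten(start_dim=1).norm(2, dim=1)
    return gp_weight * ((grad_norm - 1) ** 2).mean()
```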
4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |---------------------------|----------------------------------------|----------|-------------------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample from model | `False` | | 11 | | `--outlier_detection` | Outlier detection | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 15 | | `--batch_size` | Batch size | `256` | | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--latent_dim` | Latent dimension | `100` | | 18 | | `--d` | D | `64` | | 19 | | `--lrg` | Learning rate generator | `0.0002` | | 20 | | `--lrd` | Learning rate discriminator | `0.0002` | | 21 | | `--beta1` | Beta1 | `0.5` | | 22 | | `--beta2` | Beta2 | `0.999` | | 23 | | `--sample_and_save_freq` | Sample interval | `5` | | 24 | | `--checkpoint` | Checkpoint path | `None` | | 25 | | `--discriminator_checkpoint` | Discriminator checkpoint path | `None` | | 26 | | `--gp_weight` | Gradient penalty weight | `10.0` | | 27 | | `--n_critic` | Number of critic updates per generator update | `5` | | 28 | | `--n_samples` | Number of samples | `9` | | 29 | | `--num_workers` | Number of workers for Dataloader | `0` | | 30 | 31 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 32 | 33 | python WGAN.py --help 34 | 35 | ## Training 36 | 37 | The training command is similar to the one found in other GANs present in the zoo: 38 | 39 | python WGAN.py --train --dataset cifar10 --latent_dim 1024 40 | 41 | ## Sampling 42 | 43 | Sampling is also close to what is done in other adversarial networks: 44 | 45 | python WGAN.py --sample --dataset cifar10 --latent_dim 1024 --checkpoint ./../../models/WassersteinGAN/WGAN_cifar10.pt 46 | -------------------------------------------------------------------------------- /docs/VanillaFlow.md: -------------------------------------------------------------------------------- 1 | # Vanilla Flow 2 | 3 | A Normalizing Flow model is designed to learn the underlying probability distribution of a given dataset. In a Normalizing Flow model, the input data is transformed through a series of invertible transformations, also known as flow steps. Each flow step consists of a deterministic function that maps the input data to a new representation, and its inverse function that maps the transformed data back to the original space. The key idea behind Normalizing Flow models is that by applying a sequence of invertible transformations, the model can learn a more complex and flexible distribution. This allows the model to capture intricate patterns and dependencies in the data. During training, the parameters of the flow steps are optimized to minimize the difference between the learned distribution and the true distribution of the data. This is typically done by maximizing the likelihood of the training data.
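As a rough sketch of that likelihood computation, the example below applies the change-of-variables formula; it assumes a `flow` callable that returns the transformed variable and the summed log-determinant, which is an assumption for illustration only:

```python
import torch

def flow_nll(flow, x, bits_per_dim=False):
    """Exact negative log-likelihood via the change-of-variables formula:
    log p(x) = log p_z(f(x)) + log|det df/dx|, with a standard normal prior."""
    z, log_det = flow(x)                               # flow returns z and the summed log-determinant
    log_pz = torch.distributions.Normal(0.0, 1.0).log_prob(z).flatten(1).sum(dim=1)
    log_px = log_pz + log_det
    nll = -log_px
    if bits_per_dim:
        dims = x[0].numel()
        nll = nll / (dims * torch.log(torch.tensor(2.0)))   # convert nats to bits per dimension
    return nll.mean()
```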
4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |-----------------------------|---------------------------------------------------|-------------|---------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample model | `False` | | 11 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 12 | | `--no_wandb` | Disable Wandb | `False` | | 13 | | `--batch_size` | Batch size | `128` | | 14 | | `--n_epochs` | Number of epochs | `100` | | 15 | | `--lr` | Learning rate | `1e-3` | | 16 | | `--c_hidden` | Hidden units in the first coupling layer | `16` | | 17 | | `--multi_scale` | Use multi scale | `False` | | 18 | | `--vardeq` | Use variational dequantization | `False` | | 19 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 20 | | `--checkpoint` | Checkpoint path | `None` | | 21 | | `--outlier_detection` | Outlier detection | `False` | | 22 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `cifar100`, `places365`, `dtd`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `imagenet` | 23 | | `--n_layers` | Number of layers | `8` | | 24 | | `--num_workers` | Number of workers for Dataloader | `0` | | 25 | 26 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 27 | 28 | python VanFlow.py --help 29 | 30 | ## Training 31 | 32 | You can train this model with the following command: 33 | 34 | python VanFlow.py --train --dataset pneumoniamnist 35 | 36 | ## Sampling 37 | 38 | To sample, please provide the checkpoint: 39 | 40 | python VanFlow.py --sample --dataset pneumoniamnist --checkpoint ./../../models/VanillaFlow/VanFlow_pneumoniamnist.pt 41 | 42 | ## Outlier Detection 43 | 44 | Outlier Detection is performed by using the NLL scores generated by the model: 45 | 46 | python VanFlow.py --outlier_detection --dataset pneumoniamnist --out_dataset mnist --checkpoint ./../../models/VanillaFlow/VanFlow_pneumoniamnist.pt -------------------------------------------------------------------------------- /docs/VanillaVAE.md: -------------------------------------------------------------------------------- 1 | # Variational Autoencoder (Vanilla VAE) 2 | 3 | The Vanilla VAE (Variational Autoencoder) is a generative model that learns to encode and decode data. It is commonly used for unsupervised learning tasks such as dimensionality reduction and data generation.
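A minimal sketch of the reparameterization trick and of the training objective (reconstruction plus weighted KL divergence) is shown below; it is illustrative only and may differ from the code in `VanVAE.py`:

```python
import torch
import torch.nn.functional as F

def reparameterize(mu, log_var):
    """Sample z = mu + sigma * eps so gradients flow through mu and log_var."""
    std = (0.5 * log_var).exp()
    return mu + std * torch.randn_like(std)

def vae_loss(x, x_hat, mu, log_var, kld_weight=1e-4):
    """Standard VAE objective: reconstruction term plus weighted KL divergence
    between N(mu, sigma^2) and the standard normal prior."""
    recon = F.mse_loss(x_hat, x, reduction="mean")
    kld = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + kld_weight * kld
```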
4 | 5 | ### Parameters 6 | 7 | | Parameter | Description | Default | Choices | 8 | |-----------------|---------------------------------------|---------|----------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample model | `False` | | 11 | | `--outlier_detection` | Out-of-distribution detection | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--out_dataset` | Outlier dataset name |`mnist`| `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 15 | | `--batch_size` | Batch size | `128` | | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--lr` | Learning rate | `0.0002`| | 18 | | `--latent_dim` | Latent dimension | `128` | | 19 | | `--hidden_dims` | Hidden dimensions | `None` | | 20 | | `--checkpoint` | Checkpoint path | `None` | | 21 | | `--num_samples` | Number of samples | `16` | | 22 | | `--sample_and_save_frequency`| sample and save frequency| `5` | | 23 | | `--loss_type` | Type of loss to evaluate reconstruction | `mse` | `mse`, `ssim` | 24 | | `--kld_weight` | KL-Divergence weight | `1e-4` | | 25 | | `--num_workers` | Number of workers for Dataloader | `0` | | 26 | 27 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 28 | 29 | python VanVAE.py --help 30 | 31 | ## Training 32 | 33 | The Vanilla VAE is trained using a combination of a reconstruction loss and a regularization loss. The reconstruction loss measures the difference between the original data point and its reconstruction. The regularization loss encourages the latent space distribution to follow a prior distribution, typically a Gaussian distribution. 34 | 35 | During training, the model learns to minimize the combined loss by adjusting the parameters of the encoder and decoder networks using techniques such as backpropagation and gradient descent. 36 | 37 | To train a model on the FashionMNIST dataset, you can run the provided example script: 38 | 39 | python VanVAE.py --train --dataset fashionmnist 40 | 41 | ## Sampling 42 | 43 | Sampling from the Vanilla VAE model allows us to generate new data points based on the learned representations. To sample from the model, we can follow these steps: 44 | 45 | 1. Sample a point from the latent space. This can be done by sampling from a prior distribution, typically a Gaussian distribution with mean 0 and variance 1. 46 | 47 | 2. Pass the sampled point through the decoder network to generate a new data point. 48 | 49 | You can sample from the model you trained on FashionMNIST by running: 50 | 51 | python VanVAE.py --sample --dataset fashionmnist --checkpoint ./../../models/VanillaVAE/VanVAE_fashionmnist.pt 52 | 53 | ## Outlier Detection 54 | 55 | To detect out-of-distribution samples, we use the loss function as a way to produce an anomaly score. An in-distribution sample should have a low anomaly score, i.e., should be properly reconstructed and its latent space should approximate the prior. On the other hand, an out-of-distribution sample should have a high loss because it is poorly reconstructed and the encoded features do not follow a normal distribution. 
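A possible way to turn this loss into per-sample anomaly scores is sketched below; the model output format `(x_hat, mu, log_var)` is an assumption for illustration only:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def anomaly_scores(model, loader, kld_weight=1e-4, device="cpu"):
    """Per-sample VAE loss used as an anomaly score (higher = more anomalous)."""
    scores = []
    for x, _ in loader:
        x = x.to(device)
        x_hat, mu, log_var = model(x)   # assumed output format
        recon = F.mse_loss(x_hat, x, reduction="none").flatten(1).mean(dim=1)
        kld = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        scores.append(recon + kld_weight * kld)
    return torch.cat(scores)

# Scores computed on the in-distribution test set and on the outlier dataset can
# then be compared, e.g. with sklearn.metrics.roc_auc_score.
```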
56 | 57 | python VanVAE.py --outlier_detection --dataset fashionmnist --out_dataset mnist --checkpoint ./../../models/VanillaVAE/VanVAE_fashionmnist.pt 58 | -------------------------------------------------------------------------------- /docs/Glow.md: -------------------------------------------------------------------------------- 1 | # Glow 2 | 3 | Glow is a simple type of generative flow using an invertible 1x1 convolution. Although it is a generative model optimized towards the plain log-likelihood objective, it is capable of efficient realistic-looking synthesis and manipulation of large images. 4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |----------------------|---------------------------------------|---------|------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample from model | `False` | | 11 | | `--outlier_detection`| Outlier detection | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 15 | | `--batch_size` | Batch size | `128` | | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--lr` | Learning rate | `0.0002`| | 18 | | `--hidden_channels` | Hidden channels | `64` | | 19 | | `--K` | Number of layers per block | `8` | | 20 | | `--L` | Number of blocks | `3` | | 21 | | `--actnorm_scale` | Act norm scale | `1.0` | | 22 | | `--flow_permutation` | Flow permutation |`invconv`| `invconv`, `shuffle`, `reverse` | 23 | | `--flow_coupling` | Flow coupling |`affine` | `additive`, `affine` | 24 | | `--LU_decomposed` | Train with LU decomposed 1x1 convs |`False` | | 25 | | `--learn_top` | Learn top layer (prior) | `False` | | 26 | | `--y_condition` | Class Conditioned Glow | `False` | | 27 | | `--y_weight` | Weight of class condition | `0.01` | | 28 | | `--num_classes` | Number of classes | `10` | | 29 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 30 | | `--checkpoint` | Checkpoint path | `None` | | 31 | | `--n_bits` | Number of bits | `8` | | 32 | | `--max_grad_clip` | Max Grad clip | `0.0` | | 33 | | `--max_grad_norm` | Max Grad Norm | `0.0` | | 34 | | `--num_workers` | Number of workers for Dataloader | `0` | | 35 | | `--warmup` | Number of warmup epochs | `10` | | 36 | | `--decay` | weight decay of learning rate | `0` | | 37 | 38 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 39 | 40 | python GLOW.py --help 41 | 42 | ## Training 43 | 44 | You can train this model with the following command: 45 | 46 | python GLOW.py --train --dataset octmnist 47 | 48 | ## Sampling 49 | 50 | To sample, please provide the checkpoint: 51 | 52 | python GLOW.py --sample --dataset octmnist --checkpoint ./../../models/Glow/Glow_octmnist.pt 53 | 54 | ## Outlier Detection 55 | 56 | Outlier Detection is performed by using the NLL scores generated by the model: 57 | 58 | python GLOW.py --outlier_detection --dataset octmnist --out_dataset mnist --checkpoint ./../../models/Glow/Glow_octmnist.pt 59 | 
-------------------------------------------------------------------------------- /docs/FlowPlusPlus.md: -------------------------------------------------------------------------------- 1 | # Flow++ 2 | 3 | Flow++ is a generative model that aims to learn the underlying probability distribution of a given dataset. This work improves upon three limiting design choices employed by flow-based models in prior work: the use of uniform noise for dequantization, the use of inexpressive affine flows, and the use of purely convolutional conditioning networks in coupling layers. 4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |---------------------------|------------------------------------------|-----------------|------------------------------------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample from model | `False` | | 11 | | `--outlier_detection` | Outlier detection | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--out_dataset` | Outlier dataset name | `fashionmnist`| `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 15 | | `--batch_size` | Batch size | `8` | | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--lr` | Learning rate | `1e-3` | | 18 | | `--warm_up` | Warm up | `200` | | 19 | | `--grad_clip` | Gradient clip | `1.0` | | 20 | | `--drop_prob` | Dropout probability | `0.2` | | 21 | | `--num_blocks` | Number of blocks | `10` | | 22 | | `--num_components` | Number of components in the mixture | `32` | | 23 | | `--num_dequant_blocks` | Number of blocks in dequantization | `2` | | 24 | | `--num_channels` | Number of channels in Flow++ | `96` | | 25 | | `--use_attn` | Use attention | `False` | | 26 | | `--sample_and_save_freq` | Sample interval | `5` | | 27 | | `--checkpoint` | Checkpoint path to VQVAE | `None` | | 28 | | `--num_workers` | Number of workers for Dataloader | `0` | | 29 | 30 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 31 | 32 | python FlowPP.py --help 33 | 34 | ## Training 35 | 36 | You can train this model with the following command: 37 | 38 | python FlowPP.py --train --dataset mnist 39 | 40 | ## Sampling 41 | 42 | To sample, please provide the checkpoint: 43 | 44 | python FlowPP.py --sample --dataset fashionmnist --checkpoint ./../../models/FlowPP/FlowPP_mnist.pt 45 | 46 | ## Outlier Detection 47 | 48 | Outlier Detection is performed by using the NLL scores generated by the model: 49 | 50 | python FlowPP.py --outlier_detection --dataset mnist --out_dataset fashionmnist --checkpoint ./../../models/FlowPP/FlowPP_mnist.pt -------------------------------------------------------------------------------- /docs/RectifiedFlows.md: -------------------------------------------------------------------------------- 1 | # Rectified Flows 2 | 3 | **This model supports `Accelerate` for Multi-GPU and Mixed Precision Training.** 4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | 
|------------------------------|------------------------------------|-------------------|--------------------------------------------------------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample model | `False` | | 11 | | `--outlier_detection` | Outlier detection | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `cifar100`, `places365`, `dtd`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `imagenet` | 15 | | `--batch_size` | Batch size | `128` | | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--lr` | Learning rate | `5e-4` | | 18 | | `--patch_size` | Patch size | `2` | | 19 | | `--dim` | Dimension | `64` | | 20 | | `--n_layers` | Number of layers | `6` | | 21 | | `--n_heads` | Number of heads | `4` | | 22 | | `--multiple_of` | Multiple of | `256` | | 23 | | `--ffn_dim_multiplier` | FFN dim multiplier | `None` | | 24 | | `--norm_eps` | Norm eps | `1e-5` | | 25 | | `--class_dropout_prob` | Class dropout probability | `0.1` | | 26 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 27 | | `--num_classes` | Number of classes | `10` | | 28 | | `--checkpoint` | Checkpoint path | `None` | | 29 | | `--num_workers` | Number of workers for Dataloader | `0` | | 30 | | `--latent` | Use latent version | `False` | | 31 | | `--warmup` | Number of warmup epochs | `10` | | 32 | | `--decay` | Decay rate | `1e-5` | | 33 | | `--ema_rate` | Exponential moving average rate | `0.999` | | 34 | | `--conditional` | Conditional model | `False` | | 35 | | `--size` | Size of input image | `None` | | 36 | 37 | 38 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 39 | 40 | python RF.py --help 41 | 42 | ## Training 43 | 44 | You can train this model with the following command: 45 | 46 | accelerate launch RF.py --train --dataset mnist 47 | 48 | ## Sampling 49 | 50 | To sample, please provide the checkpoint: 51 | 52 | python RF.py --sample --dataset fashionmnist --checkpoint ./../../models/FlowMatching/FM_mnist.pt -------------------------------------------------------------------------------- /docs/NCSNv2.md: -------------------------------------------------------------------------------- 1 | # NCSNv2 2 | 3 | This work provides a new theoretical analysis of learning and sampling from score-based models in high dimensional spaces, explaining existing failure modes and motivating new solutions that generalize across datasets. To enhance stability, it also proposes to maintain an exponential moving average of model weights. 
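A minimal sketch of such an exponential moving average of the weights is shown below; it illustrates the idea behind the `--ema_decay` parameter rather than the exact `NCSNv2.py` implementation:

```python
import copy
import torch

class EMA:
    """Exponential moving average of model parameters (decay ~ 0.999)."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1 - self.decay)
```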
4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |-----------------------------|--------------------------------------------------|-------------|--------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample from model | `False` | | 11 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 12 | | `--no_wandb` | Disable Wandb | `False` | | 13 | | `--batch_size` | Batch size | `128` | | 14 | | `--n_epochs` | Number of epochs | `100` | | 15 | | `--lr` | Learning rate | `0.0002` | | 16 | | `--nf` | Number of filters | `128` | | 17 | | `--act` | Activation | `elu` | `relu`, `elu`, `swish` | 18 | | `--centered` | Centered | `False` | | 19 | | `--sigma_min` | Min value for sigma | `0.01` | | 20 | | `--sigma_max` | Max value for sigma | `50` | | 21 | | `--num_scales` | Number of scales | `232` | | 22 | | `--normalization` | Normalization | `InstanceNorm++` | `InstanceNorm`, `GroupNorm`, `VarianceNorm`, `InstanceNorm++` | 23 | | `--num_classes` | Number of classes | `10` | | 24 | | `--ema_decay` | EMA decay | `0.999` | | 25 | | `--continuous` | Continuous | `False` | | 26 | | `--reduce_mean` | Reduce mean | `False` | | 27 | | `--likelihood_weighting` | Likelihood weighting | `False` | | 28 | | `--beta1` | Beta1 | `0.9` | | 29 | | `--beta2` | Beta2 | `0.999` | | 30 | | `--weight_decay` | Weight decay | `0.0` | | 31 | | `--warmup` | Warmup | `0` | | 32 | | `--grad_clip` | Grad clip | `-1.0` | | 33 | | `--sample_and_save_freq` | Sample and save frequency | `5` | | 34 | | `--sampler` | Sampler name | `pc` | `pc`, `ode` | 35 | | `--predictor` | Predictor | `none` | `none`, `em`, `rd`, `as` | 36 | | `--corrector` | Corrector | `ald` | `none`, `l`, `ald` | 37 | | `--snr` | Signal to noise ratio | `0.176` | | 38 | | `--n_steps` | Number of steps | `5` | | 39 | | `--probability_flow` | Probability flow | `False` | | 40 | | `--noise_removal` | Noise removal | `False` | | 41 | | `--checkpoint` | Checkpoint path | `None` | | 42 | | `--num_workers` | Number of workers for Dataloader | `0` | | 43 | 44 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 45 | 46 | python NCSNv2.py --help 47 | 48 | ## Training 49 | 50 | The model can be trained with: 51 | 52 | python NCSNv2.py --train --nf 32 --noise_removal 53 | 54 | ## Sampling 55 | 56 | For sampling you must provide the checkpoint: 57 | 58 | python NCSNv2.py --sample --nf 32 --noise_removal --checkpoint ./../../models/NCSNv2/NCSNv2_mnist.pt -------------------------------------------------------------------------------- /docs/DCGAN.md: -------------------------------------------------------------------------------- 1 | # Deep Convolutional Generative Adversarial Network (DC-GAN) 2 | 3 | A Generative Adversarial Network (GAN) comprises two neural networks: a **Generator** and a **Discriminator**, engaged in a minimax game. The generator fabricates synthetic images out of a noisy input, while the discriminator evaluates the authenticity of these samples, distinguishing between real data and the generated ones.
Through iterative training, the generator learns to produce increasingly realistic outputs that deceive the discriminator, while the discriminator enhances its ability to differentiate genuine from fake data. 4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |---------------------------|----------------------------------------------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--sample` | Sample from model | `False` | | 11 | | `--outlier_detection` | Outlier detection | `False` | | 12 | | `--batch_size` | Batch size | `128` | | 13 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 14 | | `--no_wandb` | Disable Wandb | `False` | | 15 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 16 | | `--n_epochs` | Number of epochs | `100` | | 17 | | `--lrg` | Learning rate generator | `0.0002` | | 18 | | `--lrd` | Learning rate discriminator | `0.0002` | | 19 | | `--beta1` | Beta1 | `0.5` | | 20 | | `--beta2` | Beta2 | `0.999` | | 21 | | `--latent_dim` | Latent dimension | `100` | | 22 | | `--img_size` | Image size | `32` | | 23 | | `--channels` | Channels | `1` | | 24 | | `--sample_and_save_freq` | Sample interval | `5` | | 25 | | `--checkpoint` | Checkpoint path | `None` | | 26 | | `--discriminator_checkpoint` | Discriminator checkpoint path | `None` | | 27 | | `--n_samples` | Number of samples | `9` | | 28 | | `--d` | d | `128` | | 29 | | `--num_workers` | Number of workers for Dataloader | `0` | | 30 | 31 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 32 | 33 | python DCGAN.py --help 34 | 35 | ## Training 36 | 37 | Adversarial losses are used during training. The generator is encouraged to generate images that fool the discriminator into classifying them as real, while the discriminator is trained as a binary classifier to distinguish between real and generated images. The model can be trained using the following command: 38 | 39 | python DCGAN.py --train --dataset svhn 40 | 41 | ## Sampling 42 | 43 | To sample from a GAN, you input a noisy latent vector of a predefined size into the generator network. This latent vector serves as a random seed that the generator uses to generate synthetic data samples. 44 | 45 | python DCGAN.py --sample --dataset svhn --checkpoint ./../../models/DCGAN/DCGAN_svhn.pt 46 | 47 | ## Outlier Detection 48 | 49 | To perform outlier detection, only the Discriminator will be used: 50 | 51 | python DCGAN.py --outlier_detection --dataset svhn --out_dataset cifar10 --discriminator_checkpoint ./../../models/DCGAN/DCDisc_svhn.pt -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or 4 | additional documentation, we greatly value feedback and contributions. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. Any contributions should come through valid Pull/Merge Requests. 8 | 9 | ## Reporting Bugs/Feature Requests 10 | 11 | When filing an issue, please check previous issues to make sure somebody else hasn't already reported the issue. 12 | Please try to include as much information as you can. Details like these are incredibly useful: 13 | 14 | - A reproducible test case or series of steps 15 | - The version of our code being used 16 | - Any modifications you've made relevant to the bug 17 | - Anything unusual about your environment or deployment 18 | 19 | ## Contributing via Pull Requests 20 | 21 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 22 | 23 | 1. You are working against the latest source on the **master** branch. 24 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 25 | 26 | To send a pull request, please: 27 | 28 | 1. Fork the repository. 29 | 2. Modify the source; please focus on the specific change you are contributing. 30 | If you also reformat all the code, it will be hard for us to focus on your change. 31 | 3. Ensure local tests pass by executing `pytest`. 32 | 4. Commit to your fork using clear commit messages. 33 | 5. Send us a pull request, answering any default questions in the pull request interface. 34 | 35 | ## Tips for Modifying the Source Code 36 | 37 | - We recommend developing on Linux as this is the only OS where all features are currently 100% functional. 38 | - Use **Python >= 3.7** for development. 39 | - Please try to avoid introducing additional dependencies on 3rd party packages. 40 | - We encourage you to add your own unit tests, but please ensure they run quickly (unit tests should train models on 41 | a small data subsample with the lowest values of training iterations and time-limits that suffice to evaluate the intended 42 | functionality). 43 | 44 | ## Contribution Guidelines 45 | 46 | 1. Please adhere to the [PEP-8](https://www.python.org/dev/peps/pep-0008/) standards. A maximum line length of 120 characters is 47 | allowed for consistency with the C++ code. Editors such as PyCharm [(see here)](https://www.jetbrains.com/help/pycharm/code-inspection.html) 48 | and Visual Studio Code [(see here)](https://code.visualstudio.com/docs/python/linting#_flake8) can be configured to check 49 | for PEP8 issues. 50 | 2. Code can be validated with flake8 using the configuration file in the root directory called `.flake8`. 51 | 3. Run `black` and `isort` linters before creating a PR. 52 | 4. Any changes to core functionality must pass all existing unit tests. 53 | 5. Functions/methods should have type hints for arguments and return types. 54 | 6. Additional functionality should have associated unit tests. 55 | 7. Provide documentation ([Google Docstrings format](https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html)) 56 | whenever possible, even for simple functions or classes. 57 | 58 | ### Design patterns 59 | 60 | Favour the following design patterns, in this order: 61 | 62 | 1. Functional programming: using functions to do one thing well, not altering the original data. 63 | 2.
Modular code: using functions to eliminate duplication, improve readability and reduce indentation. 64 | 3. Object-oriented programming: Generally avoid, unless you are customising an API (for example `DataFrame`) or defining your own API. 65 | 66 | If you are not, at least, adhering to a modular style then you have gone very wrong. 67 | You should implement unit tests for each of your functions, something which is generally more tricky for object-oriented programming. 68 | 69 | ### Programming 70 | 71 | 1. Don't compare boolean values to `True` or `False`. 72 | 2. Favour `x is not y` over `not x is y`. 73 | 3. Don't compare a value to `None` (`value == None`), always favour `value is None`. 74 | 4. Favour [`logging`](https://docs.python.org/3/howto/logging.html) over `print`. 75 | 5. Favour using configuration files, or (faster/lazier/less good/ok) `GLOBAL_VARIABLES` near the top of your code, rather than repeated 76 | use of hard-coded variables in your code, particularly with file path variables, but also for repeated numeric hyperparameters. 77 | 78 | ### Naming conventions 79 | 80 | 1. Functions / methods: `function`, `my_function` (snake case) 81 | 2. Variables / attributes: `variable`, `my_var` (snake case) 82 | 3. Classes: `Model`, `MyClass` (camel case) 83 | 4. Module / file names / directory names: `module`, `file_name.py`, `dir_name` (snake case) 84 | 5. Global\* / constants: `A_GLOBAL_VARIABLE` (screaming snake case) 85 | 6. Keep all names as short and descriptive as possible. Variable names such as `x` or `df` are highly discouraged unless they are genuinely 86 | representing abstract concepts. 87 | 7. Favour good naming conventions over helpful comments 88 | 89 | ## Guidelines for creating a good pull request 90 | 91 | 1. A PR should describe the change clearly and most importantly it should mention the motivation behind the change. 92 | 2. If the PR is fixing a performance issue, mention the improvement and how the measurement was done (for educational purposes). 93 | 3. Do not leave comments unresolved. If PR comments have been addressed without making the requested code changes, 94 | explicitly mark them resolved with an appropriate comment explaining why you're resolving it. If you intend to resolve it 95 | in a follow-up PR, create a task and mention why this comment cannot be fixed in this PR. Leaving comments unresolved 96 | sets a wrong precedent for other contributors that it's ok to ignore comments. 97 | 4. In the interest of time, discuss the PR/comments in person if it's difficult to explain in writing. Document the 98 | resolution in the PR for the educational benefit of others. Don't just mark the comment resolved saying 'based on offline 99 | discussion'. 100 | 5. Add comments, if not obvious, in the PR to help the reviewer navigate your PR faster. If this is a big change, include 101 | a short design doc (docs/ folder). 102 | 6. Unit tests are mandatory for all PRs (except when the proposed changes are already covered by existing unit tests). 103 | 7. Do not use PRs as scratch pads for development as they consume valuable build/CI cycles for every commit. Build and 104 | test your changes for at least one environment (Windows/Linux/Mac) before creating a PR. 105 | 8. Keep it small. If the feature is big, it's best to split into multiple PRs. Modulo cosmetic changes, a PR with more 106 | than 10 files is notoriously hard to review. Be kind to the reviewers. 107 | 9. Separate cosmetic changes from functional changes by making them separate PRs.
108 | 10. The PR author is responsible for merging the changes once they're approved. 109 | 11. If you co-author a PR, seek review from someone else. Do not self-approve PRs. 110 | 111 | ## GitHub flow Workflow 112 | 113 | Please follow [GitHub Flow](https://githubflow.github.io/) 114 | 115 | In the GitHub flow workflow, there are 2 different branch types: 116 | 117 | - Master: contain production-ready code that can be released. 118 | - Feature: develop new features for the upcoming releases, to be merged to master after review 119 | 120 | ## Commit message 121 | 122 | We try to follow [Conventional Commits](https://www.conventionalcommits.org) for commit messages and PR titles 123 | through [gitlint](https://jorisroovers.com/gitlint/). 124 | 125 | You then need to install the pre-commit hook like so: 126 | 127 | ``` 128 | pre-commit install --hook-type commit-msg 129 | ``` 130 | 131 | ## More details: 132 | 133 | - [4 branching workflows for Git](https://medium.com/@patrickporto/4-branching-workflows-for-git-30d0aaee7bf) 134 | - [A Note About Git Commit Messages](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html) 135 | -------------------------------------------------------------------------------- /docs/AdversarialVAE.md: -------------------------------------------------------------------------------- 1 | # Adversarial Variational Autoencoder (Adversarial VAE) 2 | 3 | The Adversarial Variational Autoencoder (Adversarial VAE) is a generative model that combines the power of Variational Autoencoders (VAEs) with adversarial training. VAEs are a type of deep generative model that can learn to generate new data samples by capturing the underlying distribution of the training data. Adversarial training, on the other hand, involves training a discriminator network to distinguish between real and generated samples, while simultaneously training the generator network to fool the discriminator. By combining these two techniques, the Adversarial VAE can generate high-quality samples that are both diverse and realistic. 
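A hypothetical sketch of the generator-side objective is shown below; the weight names mirror the `--kld_weight`, `--recon_weight` and `--gen_weight` parameters, and the discriminator is assumed to output probabilities, which may differ from the actual `AdvVAE.py` code:

```python
import torch
import torch.nn.functional as F

def adversarial_vae_gen_loss(x, x_hat, x_sampled, mu, log_var, discriminator,
                             kld_weight=1e-4, recon_weight=0.002, gen_weight=0.002):
    """Generator-side objective: ELBO terms plus two adversarial terms that try
    to make the discriminator label reconstructions and samples as real."""
    recon = F.mse_loss(x_hat, x)
    kld = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
    real_label = torch.ones(x.size(0), 1, device=x.device)
    adv_recon = F.binary_cross_entropy(discriminator(x_hat), real_label)
    adv_gen = F.binary_cross_entropy(discriminator(x_sampled), real_label)
    return recon + kld_weight * kld + recon_weight * adv_recon + gen_weight * adv_gen
```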
4 | 5 | ## Parameters 6 | 7 | | Argument | Description | Default | Choices | 8 | |---------------------------|----------------------------------------------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 9 | | `--train` | Train model | `False` | | 10 | | `--test` | Test model | `False` | | 11 | | `--sample` | Sample model | `False` | | 12 | | `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 13 | | `--no_wandb` | Disable Wandb | `False` | | 14 | | `--batch_size` | Batch size | `128` | | 15 | | `--n_epochs` | Number of epochs | `100` | | 16 | | `--lr` | Learning rate | `0.0002` | | 17 | | `--latent_dim` | Latent dimension | `128` | | 18 | | `--hidden_dims` | Hidden dimensions | `None` | | 19 | | `--checkpoint` | Checkpoint path | `None` | | 20 | | `--num_samples` | Number of samples | `16` | | 21 | | `--gen_weight` | Generator weight | `0.002` | | 22 | | `--recon_weight` | Reconstruction weight | `0.002` | | 23 | | `--sample_and_save_frequency` | Sample and save frequency | `5` | | 24 | | `--outlier_detection` | Outlier detection | `False` | | 25 | | `--discriminator_checkpoint` | Discriminator checkpoint path | `None` | | 26 | | `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` | 27 | | `--loss_type` | Type of loss to evaluate reconstruction | `mse` | `mse`, `ssim` | 28 | | `--kld_weight` | KL-Divergence weight | `1e-4` | | 29 | | `--num_workers` | Number of workers for Dataloader | `0` | | 30 | | `--size` | Size of image (None uses default for each dataset) | `None` | | 31 | 32 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 33 | 34 | python AdvVAE.py --help 35 | 36 | ## Training 37 | 38 | To train the Generator (the VAE), we do not simply try to minimize the reconstruction loss and the KL divergence. In addition to this, we incorporate two adversarial loss factors, one related with the ability to fool the discriminator with the VAE's reconstructions (adjustable with `recon_weight`) and the other related to the ability to fool the discriminator with samples generated by the VAE (adjustable with `gen_weight`). 39 | 40 | The discriminator, on the other hand, is taught to classify both the reconstructions and the generated images as false. 41 | 42 | python AdvVAE.py --train --dataset octmnist 43 | 44 | ## Testing 45 | 46 | This is related to the ability of the model to accurately reconstruct the input, which was encouraged during the training stage of the generator. 47 | 48 | python AdvVAE.py --test --dataset octmnist 49 | 50 | ## Sampling 51 | 52 | This process is similar to the one described in [`VanillaVAE.md`](VanillaVAE.md). 53 | 54 | python AdvVAE.py --sample --dataset octmnist --checkpoint ./../../models/AdversarialVAE/AdvVAE.pt 55 | 56 | ## Outlier Detection 57 | 58 | To detect out-of-distribution samples, we can either use the loss function as a way to produce an anomaly score, or the discriminator that was used for the adversarial training process. 
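For the discriminator-based variant, a simple hypothetical scoring function could look like the sketch below, assuming the discriminator outputs a probability of the input being real with shape `(batch, 1)`:

```python
import torch

@torch.no_grad()
def discriminator_ood_score(discriminator, x):
    """Use 1 - D(x) as an anomaly score: samples the discriminator finds
    unconvincing ("fake-looking") receive higher scores."""
    return 1.0 - discriminator(x).squeeze(1)
```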
59 | 60 | python AdvVAE.py --outlier_detection --dataset octmnist --out_dataset mnist --discriminator_checkpoint ./../../models/AdversarialVAE/Discriminator_octmnist.pt -------------------------------------------------------------------------------- /src/generativezoo/models/SM/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(args, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = args.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=args.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, 
beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 | self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class 
ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /docs/FlowMatching.md: -------------------------------------------------------------------------------- 1 | # Flow Matching 2 | 3 | **This model supports `Accelerate` for Multi-GPU and Mixed Precision Training.** 4 | 5 | ## Parameters 6 | 7 | | Argument | Default | Help | Choices | 8 | |------------------------------|------------------------|-----------------------------------------------|--------------------------------------------------------------------------------------------------------------| 9 | | `--train` | `False` | Train model | | 10 | | `--sample` | `False` | Sample model | | 11 | | `--batch_size` | `256` | Batch size | | 12 | | `--n_epochs` | `100` | Number of epochs | | 13 | | `--lr` | `1e-3` | Learning rate | | 14 | | `--model_channels` | `64` | Number of features | | 15 | | `--num_res_blocks` | `2` | Number of residual blocks per downsample | | 16 | | `--attention_resolutions` | `[4]` | Downsample rates for attention | | 17 | | `--dropout` | `0.0` | Dropout probability | | 18 | | `--channel_mult` | `[1, 2, 2]` | Channel multiplier for UNet levels | | 19 | | `--conv_resample` | `True` | Use learned convolutions for resampling | | 20 | | `--dims` | `2` | Signal dimensionality (1D, 2D, 3D) | | 21 | | `--num_heads` | `4` | Number of attention heads per layer | | 22 | | `--num_head_channels` | `32` | Fixed channel width per attention head | | 23 | | `--use_scale_shift_norm` | `False` | Use FiLM-like conditioning mechanism | | 24 | | `--resblock_updown` | `False` | Use residual blocks for up/downsampling | | 25 | | `--use_new_attention_order` | `False` | Use an alternative attention pattern | | 26 | | `--sample_and_save_freq` | `5` | Sample and save frequency | | 27 | | `--dataset` | `mnist` | Dataset name | `mnist`, `cifar10`, `cifar100`, `places365`, `dtd`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `imagenet` | 28 | | `--checkpoint` | `None` | Checkpoint path | | 29 | | `--num_samples` | `16` | Number of samples | | 30 | | `--out_dataset` | `fashionmnist` | Outlier dataset name | 
`mnist`, `cifar10`, `cifar100`, `places365`, `dtd`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `imagenet` | 31 | | `--outlier_detection` | `False` | Enable outlier detection | | 32 | | `--interpolation` | `False` | Enable interpolation | | 33 | | `--solver_lib` | `none` | Solver library | `torchdiffeq`, `zuko`, `none` | 34 | | `--step_size` | `0.1` | Step size for ODE solver | | 35 | | `--solver` | `dopri5` | Solver for ODE | `dopri5`, `rk4`, `dopri8`, `euler`, `bosh3`, `adaptive_heun`, `midpoint`, `explicit_adams`, `implicit_adams` | 36 | | `--no_wandb` | `False` | Disable Wandb logging | | 37 | | `--num_workers` | `0` | Number of workers for Dataloader | | 38 | | `--warmup` | `10` | Number of warmup epochs | | 39 | | `--decay` | `1e-5` | Decay rate | | 40 | | `--latent` | `False` | Use latent version | | 41 | | `--ema_rate` | `0.999` | Exponential moving average rate | | 42 | | `--size` | `None` | Size of input image | | 43 | 44 | 45 | 46 | You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the following command on the example script: 47 | 48 | python FM.py --help 49 | 50 | ## Training 51 | 52 | You can train this model with the following command: 53 | 54 | accelerate launch FM.py --train --dataset mnist 55 | 56 | ## Sampling 57 | 58 | To sample, please provide the checkpoint: 59 | 60 | python FM.py --sample --dataset mnist --checkpoint ./../../models/FlowMatching/FM_mnist.pt 61 | 62 | ## Outlier Detection 63 | 64 | Outlier Detection is performed by using the NLL scores generated by the model: 65 | 66 | python FM.py --outlier_detection --dataset mnist --out_dataset fashionmnist --checkpoint ./../../models/FlowMatching/FM_mnist.pt -------------------------------------------------------------------------------- /docs/PrescribedGAN.md: -------------------------------------------------------------------------------- 1 | # Prescribed Generative Adversarial Networks (PresGANs) 2 | 3 | PresGANs add noise to the output of a density network and optimize an entropy-regularized adversarial loss. The added noise renders tractable approximations of the predictive log-likelihood and stabilizes the training procedure. The entropy regularizer encourages PresGANs to capture all the modes of the data distribution. 
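A minimal, hypothetical sketch of the noisy generator output is shown below; the parameter names echo `--logsigma_init`, `--sigma_min` and `--sigma_max`, but this is not the exact `PresGAN.py` implementation:

```python
import torch
import torch.nn as nn

class NoisyGenerator(nn.Module):
    """Wraps a generator with learned per-pixel noise: x = G(z) + sigma * eps."""

    def __init__(self, generator, img_shape, logsigma_init=-1.0,
                 sigma_min=0.01, sigma_max=0.3):
        super().__init__()
        self.generator = generator
        self.log_sigma = nn.Parameter(torch.full(img_shape, logsigma_init))
        self.sigma_min, self.sigma_max = sigma_min, sigma_max

    def forward(self, z):
        sigma = self.log_sigma.exp().clamp(self.sigma_min, self.sigma_max)
        mean = self.generator(z)
        # Adding Gaussian noise yields a tractable observation model around G(z).
        return mean + sigma * torch.randn_like(mean)
```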

## Parameters

| Argument | Description | Default | Choices |
|---------------------------|----------------------------------------------------|----------|---------|
| `--train` | Train model | `False` | |
| `--sample` | Sample from model | `False` | |
| `--dataset` | Dataset name | `mnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` |
| `--no_wandb` | Disable Wandb logging | `False` | |
| `--nz` | Size of the latent z vector | `100` | |
| `--ngf` | Number of generator feature maps | `64` | |
| `--ndf` | Number of discriminator feature maps | `64` | |
| `--batch_size` | Input batch size | `64` | |
| `--n_epochs` | Number of epochs to train for | `100` | |
| `--lrD` | Learning rate for discriminator | `0.0002` | |
| `--lrG` | Learning rate for generator | `0.0002` | |
| `--lrE` | Learning rate | `0.0002` | |
| `--beta1` | Beta1 for Adam | `0.5` | |
| `--checkpoint` | Checkpoint file for generator | `None` | |
| `--discriminator_checkpoint` | Checkpoint file for discriminator | `None` | |
| `--sigma_checkpoint` | Checkpoint file for the generator's log sigma | `None` | |
| `--num_gen_images` | Number of images to generate for inspection | `16` | |
| `--sigma_lr` | Learning rate for sigma (generator variance) | `0.0002` | |
| `--lambda_` | Entropy coefficient | `0.01` | |
| `--sigma_min` | Min value for sigma | `0.01` | |
| `--sigma_max` | Max value for sigma | `0.3` | |
| `--logsigma_init` | Initial value for log sigma | `-1.0` | |
| `--num_samples_posterior` | Number of samples from the posterior | `2` | |
| `--burn_in` | HMC burn-in steps | `2` | |
| `--leapfrog_steps` | Number of leapfrog steps for HMC | `5` | |
| `--flag_adapt` | Adapt HMC step size (`0` or `1`) | `1` | |
| `--delta` | Delta for HMC | `1.0` | |
| `--hmc_learning_rate` | Learning rate for HMC | `0.02` | |
| `--hmc_opt_accept` | HMC optimal acceptance rate | `0.67` | |
| `--stepsize_num` | Initial value for HMC step size | `1.0` | |
| `--restrict_sigma` | Whether to restrict sigma to [`sigma_min`, `sigma_max`] (`0` or `1`) | `0` | |
| `--sample_and_save_freq` | Sample and save frequency | `5` | |
| `--outlier_detection` | Enable outlier detection | `False` | |
| `--out_dataset` | Outlier dataset name | `fashionmnist` | `mnist`, `cifar10`, `fashionmnist`, `chestmnist`, `octmnist`, `tissuemnist`, `pneumoniamnist`, `svhn`, `tinyimagenet`, `cifar100`, `places365`, `dtd`, `imagenet` |
| `--num_workers` | Number of workers for Dataloader | `0` | |

You can find out more about the parameters by checking [`util.py`](./../src/generativezoo/utils/util.py) or by running the example script with the `--help` flag:

    python PresGAN.py --help

## Training

The PresGAN can be trained in a similar fashion to the other GANs in the zoo:

    python PresGAN.py --train --dataset tinyimagenet --restrict_sigma 1 --sigma_min 1e-3 --sigma_max 0.3 --lambda 5e-4 --nz 1024

## Sampling

For sampling, you must provide the generator checkpoint:

    python PresGAN.py --sample --dataset tinyimagenet --nz 1024 --checkpoint ./../../models/PrescribedGAN/PresGAN_tinyimagenet.pt

## Outlier Detection

To perform outlier detection, you must provide the discriminator checkpoint and the sigma checkpoint:

    python PresGAN.py --outlier_detection --dataset tinyimagenet --out_dataset cifar10 --nz 1024 --discriminator_checkpoint ./../../models/PrescribedGAN/PresDisc_tinyimagenet.pt --sigma_checkpoint ./../../models/PrescribedGAN/PresSigma_tinyimagenet.pt

--------------------------------------------------------------------------------