├── .gitignore ├── README.md ├── app.py ├── assets ├── beauty_makeup.txt ├── gradio.png ├── mt_makeup.txt └── nomakeup.txt ├── config.py ├── configs ├── base.yaml ├── makeup.yaml ├── model_256_256.yaml ├── model_256_256_eval_only.yaml ├── model_64_64.yaml └── non_makeup.yaml ├── data └── mt_text_anno.json ├── dataset ├── __init__.py ├── makeup_dataset.py ├── prefetcher.py └── text_dataset.py ├── dnnlib ├── __init__.py └── util.py ├── docs ├── index.html └── static │ ├── css │ ├── bulma-carousel.min.css │ ├── bulma-slider.min.css │ ├── bulma.css.map.txt │ ├── bulma.min.css │ ├── fontawesome.all.min.css │ └── index.css │ ├── images │ ├── all_in_one_vis.svg │ ├── makeup_to_non.svg │ ├── non_to_makeup.svg │ ├── pipeline.svg │ ├── scale_comp_combine.svg │ ├── teaser.svg │ ├── text-example.svg │ └── text_modification.svg │ ├── js │ ├── bulma-carousel.js │ ├── bulma-carousel.min.js │ ├── bulma-slider.js │ ├── bulma-slider.min.js │ ├── fontawesome.all.min.js │ └── index.js │ └── poster.pdf ├── generate_text_editing.py ├── generate_transfer.py ├── generate_translation.py ├── main.py ├── misc ├── __init__.py ├── compute_classification.py ├── compute_metrics.py ├── compute_removal_metrics.py ├── compute_text_metrics.py ├── constant.py ├── convert_beauty_face.py ├── meter.py └── morphing.py ├── modeling ├── __init__.py ├── generate.py ├── loss.py ├── model.py ├── scheduler.py ├── text_translation.py └── translation.py ├── requirements.txt ├── script └── train_text_to_image.sh ├── sd_training └── train_text_to_image.py ├── torch_fidelity ├── __init__.py ├── datasets.py ├── defaults.py ├── deprecations.py ├── feature_extractor_base.py ├── feature_extractor_clip.py ├── feature_extractor_dinov2.py ├── feature_extractor_inceptionv3.py ├── feature_extractor_vgg16.py ├── fidelity.py ├── generative_model_base.py ├── generative_model_modulewrapper.py ├── generative_model_onnx.py ├── helpers.py ├── interpolate_compat_tensorflow.py ├── metric_fid.py ├── metric_isc.py ├── metric_kid.py ├── metric_ppl.py ├── metric_prc.py ├── metrics.py ├── noise.py ├── registry.py ├── sample_similarity_base.py ├── sample_similarity_lpips.py ├── utils.py ├── utils_torch.py ├── utils_torchvision.py └── version.py └── torch_utils ├── __init__.py ├── distributed.py ├── misc.py ├── persistence.py ├── torch_dct.py └── training_stats.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .huskyrc.json 3 | out 4 | log.log 5 | **/node_modules 6 | *.pyc 7 | *.vsix 8 | envVars.txt 9 | **/.vscode/.ropeproject/** 10 | **/testFiles/**/.cache/** 11 | *.noseids 12 | .nyc_output 13 | .vscode-test 14 | __pycache__ 15 | npm-debug.log 16 | **/.mypy_cache/** 17 | !yarn.lock 18 | coverage/ 19 | cucumber-report.json 20 | **/.vscode-test/** 21 | **/.vscode test/** 22 | **/.vscode-smoke/** 23 | **/.venv*/ 24 | port.txt 25 | precommit.hook 26 | pythonFiles/lib/** 27 | pythonFiles/get-pip.py 28 | debug_coverage*/** 29 | languageServer/** 30 | languageServer.*/** 31 | bin/** 32 | obj/** 33 | .pytest_cache 34 | tmp/** 35 | .python-version 36 | .vs/ 37 | test-results*.xml 38 | xunit-test-results.xml 39 | build/ci/performance/performance-results.json 40 | !build/ 41 | debug*.log 42 | debugpy*.log 43 | pydevd*.log 44 | nodeLanguageServer/** 45 | nodeLanguageServer.*/** 46 | dist/** 47 | # translation files 48 | *.xlf 49 | package.nls.*.json 50 | l10n/ 51 | runs/ 52 | generate/ 53 | wandb/ 54 | .env 55 | *.ipynb 56 | generate_outputs 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | MAD: Makeup All-in-One with Cross-Domain Diffusion Model 4 | 5 | A unified cross-domain diffusion model for various makeup tasks 6 | 7 | 8 | 9 | License: Apache2.0 10 | 11 | 12 | Pipeline Image 13 | 14 |
15 | 16 | > Bo-Kai Ruan, Hong-Han Shuai 17 | > 18 | > * Contact: Bo-Kai Ruan 19 | > * [arXiv paper](https://arxiv.org/abs/2504.02545) | [Project Website](https://basiclab.github.io/MAD) 20 | 21 | ## 🚀 A. Installation 22 | 23 | ### Step 1: Create Environment 24 | 25 | * Ubuntu 22.04 with Python ≥ 3.10 (tested with GPU using CUDA 11.8) 26 | 27 | ```shell 28 | conda create --name mad python=3.10 -y 29 | conda activate mad 30 | ``` 31 | 32 | ### Step 2: Install Dependencies 33 | 34 | ```shell 35 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y 36 | conda install xformers -c xformers -y 37 | pip install -r requirements.txt 38 | 39 | # Weights for landmarks 40 | wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 41 | bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 && mkdir weights && mv shape_predictor_68_face_landmarks.dat weights 42 | ``` 43 | 44 | ### Step 3: Prepare the Dataset 45 | 46 | The following table provides download links for the datasets: 47 | 48 | | Dataset | Link | 49 | | ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 50 | | MT Dataset | [all](https://drive.google.com/file/d/18UlvYDL6UGZ2rs0yaDsSzoUlw8KI5ABY/view) | 51 | | BeautyFace Dataset | [images](https://drive.google.com/file/d/1mhoopmi7OlsClOuKocjldGbTYnyDzNMc/view?usp=sharing), [parsing map](https://drive.google.com/file/d/1WgadvcV1pUtEMCYxjwWBledEQfDbadn7/view?usp=sharing) | 52 | 53 | We recommend unzipping and placing the datasets in the same folder with the following structure: 54 | 55 | ```plaintext 56 | 📦 data 57 | ┣ 📂 mtdataset 58 | ┃ ┣ 📂 images 59 | ┃ ┃ ┣ 📂 makeup 60 | ┃ ┃ ┗ 📂 non-makeup 61 | ┃ ┣ 📂 parsing 62 | ┃ ┃ ┣ 📂 makeup 63 | ┃ ┃ ┗ 📂 non-makeup 64 | ┣ 📂 beautyface 65 | ┃ ┣ 📂 images 66 | ┃ ┗ 📂 parsing 67 | ┗ ... 68 | ``` 69 | 70 | Run `misc/convert_beauty_face.py` to convert the parsing maps for the BeautyFace dataset: 71 | 72 | ```shell 73 | python misc/convert_beauty_face.py --original data/beautyface/parsing --output data/beautyface/parsing 74 | ``` 75 | 76 | We also provide the labeling text dataset [here](data/mt_text_anno.json). 77 | 78 | ## 📦 B. Usage 79 | 80 | The pretrained weight is uploaded to [Hugging Face](https://huggingface.co/Justin900/MAD) 🤗. 81 | 82 | ### B.1 Training a Model 83 | 84 | * With our model 85 | 86 | ```shell 87 | # Single GPU 88 | python main.py --config configs/model_256_256.yaml 89 | 90 | # Multi-GPU 91 | accelerate launch --multi_gpu --num_processes={NUM_OF_GPU} main.py --config configs/model_256_256.yaml 92 | ``` 93 | 94 | * With stable diffusion 95 | 96 | ```shell 97 | ./script/train_text_to_image.sh 98 | ``` 99 | 100 | ### B.2 Beauty Filter or Makeup Removal 101 | 102 | To use the beauty filter or perform makeup removal, create a `.txt` file listing the images. Here's an example: 103 | 104 | ```plaintext 105 | makeup/xxxx1.jpg 106 | makeup/xxxx2.jpg 107 | ``` 108 | 109 | Use the `source-label` and `target-label` arguments to choose between beauty filtering or makeup removal. `0` is for makeup images and `1` is for non-makeup images. 
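If you prefer not to write these list files by hand, a snippet along the following lines can generate one (a minimal sketch assuming the dataset layout from Section A; the output path `assets/my_makeup.txt` is only an illustrative example, not a file shipped with the repository):

```python
import glob
import os

# Collect every makeup image in the MT dataset and write its path
# relative to --source-root (data/mtdataset/images) into a list file.
root = "data/mtdataset/images"
files = sorted(
    glob.glob(os.path.join(root, "makeup", "*.jpg"))
    + glob.glob(os.path.join(root, "makeup", "*.png"))
)
with open("assets/my_makeup.txt", "w") as f:
    for path in files:
        f.write(os.path.relpath(path, root) + "\n")
```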
110 | 111 | For makeup removal: 112 | 113 | ```shell 114 | python generate_translation.py \ 115 | --config configs/model_256_256.yaml \ 116 | --save-folder removal_results \ 117 | --source-root data/mtdataset/images \ 118 | --source-list assets/mt_makeup.txt \ 119 | --source-label 0 \ 120 | --target-label 1 \ 121 | --num-process {NUM_PROCESS} \ 122 | --opts MODEL.PRETRAINED Justin900/MAD 123 | ``` 124 | 125 | ### B.3 Makeup Transfer 126 | 127 | For makeup transfer, prepare two `.txt` files: one for source images and one for target images. Example: 128 | 129 | ```plaintext 130 | # File 1 | # File 2 131 | makeup/xxxx1.jpg | non-makeup/xxxx1.jpg 132 | makeup/xxxx2.jpg | non-makeup/xxxx2.jpg 133 | ... | ... 134 | ``` 135 | 136 | To apply makeup transfer: 137 | 138 | ```shell 139 | python generate_transfer.py \ 140 | --config configs/model_256_256.yaml \ 141 | --save-folder transfer_result \ 142 | --source-root data/mtdataset/images \ 143 | --target-root data/beautyface/images \ 144 | --source-list assets/nomakeup.txt \ 145 | --target-list assets/beauty_makeup.txt \ 146 | --source-label 1 \ 147 | --target-label 0 \ 148 | --num-process {NUM_PROCESS} \ 149 | --inpainting \ 150 | --cam \ 151 | --opts MODEL.PRETRAINED Justin900/MAD 152 | ``` 153 | 154 | ### B.4 Text Modification 155 | 156 | For text modification, prepare a JSON file in the following format: 157 | 158 | ```json 159 | [ 160 | {"image": "xxx.jpg", "style": "makeup with xxx"}, 161 | ... 162 | ] 163 | ``` 164 | 165 | ```shell 166 | python generate_text_editing.py \ 167 | --save-folder text_editing_results \ 168 | --source-root data/mtdataset/images \ 169 | --source-list assets/text_editing.json \ 170 | --num-process {NUM_PROCESS} \ 171 | --model-path Justin900/MAD 172 | ``` 173 | 174 | ## 🎨 C. Web UI 175 | 176 | Users can start our web UI, built with [gradio](https://github.com/gradio-app/gradio), and access it at `localhost:7860`: 177 | 178 | ```shell 179 | python app.py 180 | ``` 181 | 182 | ![gradio](assets/gradio.png) 183 | 184 | ## Citation 185 | 186 | ```bibtex 187 | @article{ruan2025mad, 188 | title={MAD: Makeup All-in-One with Cross-Domain Diffusion Model}, 189 | author={Ruan, Bo-Kai and Shuai, Hong-Han}, 190 | journal={arXiv preprint arXiv:2504.02545}, 191 | year={2025} 192 | } 193 | ``` 194 | -------------------------------------------------------------------------------- /assets/beauty_makeup.txt: -------------------------------------------------------------------------------- 1 | 3_49.png 2 | 709_20_09.png 3 | -------------------------------------------------------------------------------- /assets/gradio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/basiclab/MAD/f8b0b012eb3886274df41da9dee0eed6e07e3a34/assets/gradio.png -------------------------------------------------------------------------------- /assets/mt_makeup.txt: -------------------------------------------------------------------------------- 1 | makeup/vHX44.png 2 | makeup/vFG436.png 3 | -------------------------------------------------------------------------------- /assets/nomakeup.txt: -------------------------------------------------------------------------------- 1 | non-makeup/xfsy_0458.png 2 | non-makeup/vSYYZ388.png 3 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pprint 3 | 4 | from colorama import Fore, Style 5 | from tabulate import tabulate
6 | from yacs.config import CfgNode as CN 7 | 8 | 9 | def create_cfg(): 10 | cfg = CN() 11 | cfg._BASE_ = None 12 | cfg.PROJECT_NAME = "Makeup Transfer with Diffusion" 13 | cfg.PROJECT_DIR = None 14 | 15 | # ##### Model setup ##### 16 | cfg.MODEL = CN() 17 | cfg.MODEL.IN_CHANNELS = 3 18 | cfg.MODEL.OUT_CHANNELS = cfg.MODEL.IN_CHANNELS 19 | cfg.MODEL.LAYERS_PER_BLOCK = 2 20 | cfg.MODEL.BASE_DIM = 128 21 | cfg.MODEL.LAYER_SCALE = [1, 1, 2, 2, 4, 4] 22 | cfg.MODEL.PRETRAINED = None # "pretrained/pretrained.pkl" 23 | cfg.MODEL.LABEL_DIM = 0 24 | cfg.MODEL.DOWN_BLOCK_TYPE = (["CrossAttnDownBlock2D"] * (len(cfg.MODEL.LAYER_SCALE) - 1)) + [ 25 | "DownBlock2D" 26 | ] 27 | cfg.MODEL.UP_BLOCK_TYPE = (["CrossAttnUpBlock2D"] * (len(cfg.MODEL.LAYER_SCALE) - 1)) + [ 28 | "UpBlock2D" 29 | ] 30 | 31 | # ###### Training set ###### 32 | cfg.TRAIN = CN() 33 | cfg.TRAIN.MAKEUP = None 34 | 35 | # Log and save 36 | cfg.TRAIN.RESUME = None 37 | cfg.TRAIN.IMAGE_SIZE = 256 38 | cfg.TRAIN.LOG_INTERVAL = 20 39 | cfg.TRAIN.SAVE_INTERVAL = 10000 40 | cfg.TRAIN.SAMPLE_INTERVAL = 10000 41 | cfg.TRAIN.ROOT = None 42 | cfg.TRAIN.TEXT_LABEL_PATH = None 43 | 44 | # Training iteration 45 | cfg.TRAIN.BATCH_SIZE = 2 46 | cfg.TRAIN.NUM_WORKERS = 4 47 | cfg.TRAIN.MAX_ITER = 350000 48 | cfg.TRAIN.GRADIENT_ACCUMULATION_STEPS = 16 49 | cfg.TRAIN.MIXED_PRECISION = "no" 50 | cfg.TRAIN.GRAD_NORM = 1.0 51 | 52 | # EMA setup 53 | cfg.TRAIN.EMA_MAX_DECAY = 0.9999 54 | cfg.TRAIN.EMA_INV_GAMMA = 1.0 55 | cfg.TRAIN.EMA_POWER = 0.75 56 | 57 | # Optimizer 58 | cfg.TRAIN.LR = 0.0001 59 | cfg.TRAIN.LR_WARMUP = 1000 60 | 61 | # Diffusion setup 62 | cfg.TRAIN.TIME_STEPS = 1000 63 | cfg.TRAIN.SAMPLE_STEPS = cfg.TRAIN.TIME_STEPS 64 | cfg.TRAIN.NOISE_SCHEDULER = CN() 65 | # ///// for linear start \\\\\\\ 66 | cfg.TRAIN.NOISE_SCHEDULER.BETA_START = 1e-4 67 | cfg.TRAIN.NOISE_SCHEDULER.BETA_END = 0.02 68 | # ///// for linear end \\\\\\\ 69 | cfg.TRAIN.NOISE_SCHEDULER.TYPE = "linear" 70 | cfg.TRAIN.NOISE_SCHEDULER.PRED_TYPE = "epsilon" 71 | 72 | # ======= Evaluation set ======= 73 | cfg.EVAL = CN() 74 | cfg.EVAL.BATCH_SIZE = 4 75 | cfg.EVAL.SAMPLE_STEPS = 1000 76 | cfg.EVAL.ETA = 0.01 77 | cfg.EVAL.REFINE_STEPS = 0 78 | cfg.EVAL.REFINE_ITERATIONS = 0 79 | cfg.EVAL.SCHEDULER = "ddpm" 80 | 81 | return cfg 82 | 83 | 84 | def merge_possible_with_base(cfg: CN, config_path): 85 | with open(config_path, "r") as f: 86 | new_cfg = cfg.load_cfg(f) 87 | if "_BASE_" in new_cfg: 88 | cfg.merge_from_file(osp.join(osp.dirname(config_path), new_cfg._BASE_)) 89 | cfg.merge_from_other_cfg(new_cfg) 90 | 91 | 92 | def split_into(v): 93 | res = "(\n" 94 | for item in v: 95 | res += f" {item},\n" 96 | res += ")" 97 | return res 98 | 99 | 100 | def pretty_print_cfg(cfg): 101 | def _indent(s_, num_spaces): 102 | s = s_.split("\n") 103 | if len(s) == 1: 104 | return s_ 105 | first = s.pop(0) 106 | s = [(num_spaces * " ") + line for line in s] 107 | s = "\n".join(s) 108 | s = first + "\n" + s 109 | return s 110 | 111 | r = "" 112 | s = [] 113 | for k, v in sorted(cfg.items()): 114 | seperator = "\n" if isinstance(v, CN) else " " 115 | attr_str = "{}:{}{}".format( 116 | str(k), 117 | seperator, 118 | pretty_print_cfg(v) if isinstance(v, CN) else pprint.pformat(v), 119 | ) 120 | attr_str = _indent(attr_str, 2) 121 | s.append(attr_str) 122 | r += "\n".join(s) 123 | return r 124 | 125 | 126 | def show_config(cfg): 127 | table = tabulate( 128 | {"Configuration": [pretty_print_cfg(cfg)]}, headers="keys", tablefmt="fancy_grid" 129 | ) 130 | print(f"{Fore.BLUE}", end="") 131 | 
print(table) 132 | print(f"{Style.RESET_ALL}", end="") 133 | -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | TRAIN: 2 | ROOT: ["data/mtdataset/images"] 3 | TEXT_LABEL_PATH: "data/mt_text_anno.json" 4 | MIXED_PRECISION: "fp16" -------------------------------------------------------------------------------- /configs/makeup.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: base.yaml 2 | PROJECT_DIR: runs/makeup 3 | TRAIN: 4 | MAKEUP: true -------------------------------------------------------------------------------- /configs/model_256_256.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: base.yaml 2 | PROJECT_DIR: runs/mixup_256_256 3 | MODEL: 4 | LABEL_DIM: 2 5 | TRAIN: 6 | BATCH_SIZE: 4 7 | GRADIENT_ACCUMULATION_STEPS: 8 8 | MAX_ITER: 700000 -------------------------------------------------------------------------------- /configs/model_256_256_eval_only.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: base.yaml 2 | PROJECT_DIR: runs/for_eval_only 3 | MODEL: 4 | LABEL_DIM: 2 5 | TRAIN: 6 | ROOT: ["data/mtdataset/images", "data/beautyface/images"] 7 | BATCH_SIZE: 4 8 | GRADIENT_ACCUMULATION_STEPS: 8 9 | MAX_ITER: 350000 -------------------------------------------------------------------------------- /configs/model_64_64.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: base.yaml 2 | PROJECT_DIR: runs/mixup_64_64 3 | TRAIN: 4 | IMAGE_SIZE: 64 5 | BATCH_SIZE: 16 6 | GRADIENT_ACCUMULATION_STEPS: 4 7 | MAX_ITER: 700000 8 | MODEL: 9 | BASE_DIM: 192 10 | LAYER_SCALE: [1, 2, 3, 4] 11 | LAYERS_PER_BLOCK: 3 12 | LABEL_DIM: 2 13 | -------------------------------------------------------------------------------- /configs/non_makeup.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: base.yaml 2 | PROJECT_DIR: runs/non_makeup 3 | TRAIN: 4 | MAKEUP: false -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .makeup_dataset import get_makeup_loader 2 | from .prefetcher import DataPrefetcher 3 | 4 | __all__ = ["DataPrefetcher", "get_makeup_loader"] 5 | -------------------------------------------------------------------------------- /dataset/makeup_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import os.path as osp 5 | import random 6 | 7 | import torch 8 | from PIL import Image, ImageFile 9 | from sklearn.model_selection import train_test_split 10 | from transformers import CLIPTokenizer 11 | 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | COMP_TO_TEXT = { 15 | "face": "skin", 16 | "eye": "eyes", 17 | } 18 | 19 | CHECK_TEXT = { 20 | "face": "skin", 21 | "lips": "lip", 22 | "eyes": "eye", 23 | } 24 | 25 | 26 | class MakeupDataset: 27 | def __init__( 28 | self, 29 | root_list, 30 | text_label_path, 31 | use_label=False, 32 | train=True, 33 | makeup=False, 34 | transforms=None, 35 | ): 36 | # Important!!!!!!!!!! 37 | # Adding beautyface or wild to `root_list` is only allowed for training a removal model for evaluation. 
38 | # The removal model should not be used to generate the result, which may lead to unfair comparison. 39 | 40 | if not isinstance(root_list, list): 41 | root_list = [root_list] 42 | 43 | self.img_list = [] 44 | self.use_label = use_label 45 | self.label_to_idx = {} 46 | 47 | with open(text_label_path, "r") as f: 48 | self.text_label = json.load(f) 49 | 50 | self.tokenizer = CLIPTokenizer.from_pretrained( 51 | "CompVis/stable-diffusion-v1-4", subfolder="tokenizer" 52 | ) 53 | 54 | if use_label: 55 | for root in root_list: 56 | for catego in ["", "makeup", "non-makeup"]: 57 | if not osp.exists(osp.join(root, catego)): 58 | continue 59 | self.img_list.extend( 60 | list(glob.glob(osp.join(root, catego, "*.png"))) 61 | + list(glob.glob(osp.join(root, catego, "*.jpg"))) 62 | ) 63 | # 0: makeup, 1: non-makeup 64 | for img in self.img_list: 65 | folder_name = osp.basename(osp.dirname(img)) 66 | if folder_name == "non-makeup": 67 | self.label_to_idx[folder_name] = 1 68 | else: 69 | self.label_to_idx[folder_name] = 0 70 | else: 71 | for root in root_list: 72 | img_dir = "makeup" if makeup else "non-makeup" 73 | if not osp.exists(osp.join(root, img_dir)): 74 | continue 75 | self.img_list.extend( 76 | list(glob.glob(osp.join(root, img_dir, "*.png"))) 77 | + list(glob.glob(osp.join(root, img_dir, "*.jpg"))) 78 | ) 79 | self.img_list.sort() # ensure the order is the same on different machines 80 | 81 | # Create label list 82 | train_split, val_split = train_test_split(self.img_list, test_size=0.1, random_state=42) 83 | self.img_list = train_split if train else val_split 84 | self.transforms = transforms 85 | 86 | def __len__(self): 87 | return len(self.img_list) 88 | 89 | def __getitem__(self, idx): 90 | img_name = self.img_list[idx] 91 | img = Image.open(img_name).convert("RGB") 92 | if self.transforms: 93 | img = self.transforms(img) 94 | 95 | folder_name = osp.basename(osp.dirname(img_name)) 96 | label = self.label_to_idx[folder_name] if self.use_label else None 97 | 98 | if random.random() > 0.7: 99 | if os.path.basename(img_name) in self.text_label: 100 | all_desps = self.text_label[os.path.basename(img_name)] 101 | desp_list = [] 102 | for comp_name, desp in all_desps.items(): 103 | out = random.choice(desp).strip().lower() 104 | if CHECK_TEXT.get(comp_name, comp_name) not in out: 105 | out = f"{out} {COMP_TO_TEXT.get(comp_name, comp_name)}" 106 | desp_list.append(out) 107 | random.shuffle(desp_list) 108 | desp = "makeup with " + ", ".join(desp_list) 109 | else: 110 | desp = "no or light makeup" 111 | else: 112 | desp = "" 113 | 114 | text_inputs = self.tokenizer( 115 | desp, 116 | padding="max_length", 117 | max_length=self.tokenizer.model_max_length, 118 | truncation=True, 119 | return_tensors="pt", 120 | ) 121 | text_input_ids = text_inputs.input_ids 122 | return { 123 | "image": img, 124 | "text": text_input_ids, 125 | "label": ( 126 | torch.nn.functional.one_hot(torch.LongTensor([label]), len(self.label_to_idx)) 127 | .squeeze(0) 128 | .float() 129 | if self.use_label 130 | else None 131 | ), 132 | } 133 | 134 | 135 | def get_makeup_loader(cfg, train, transforms): 136 | dataset = MakeupDataset( 137 | cfg.TRAIN.ROOT, 138 | cfg.TRAIN.TEXT_LABEL_PATH, 139 | use_label=cfg.MODEL.LABEL_DIM > 0, 140 | train=train, 141 | makeup=cfg.TRAIN.MAKEUP, 142 | transforms=transforms, 143 | ) 144 | return torch.utils.data.DataLoader( 145 | dataset, 146 | shuffle=train, 147 | batch_size=cfg.TRAIN.BATCH_SIZE, 148 | num_workers=cfg.TRAIN.NUM_WORKERS, 149 | pin_memory=True, 150 | drop_last=True, 151 | ) 152 | 
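A minimal, hedged usage sketch of the loader defined above (the transform pipeline here is an assumption made for illustration; the actual training transforms live in `main.py`, which is not shown in this listing):

```python
from torchvision import transforms

from config import create_cfg, merge_possible_with_base
from dataset import get_makeup_loader

cfg = create_cfg()
merge_possible_with_base(cfg, "configs/model_256_256.yaml")

# Assumed preprocessing: resize to the training resolution and scale to [-1, 1].
train_transforms = transforms.Compose(
    [
        transforms.Resize((cfg.TRAIN.IMAGE_SIZE, cfg.TRAIN.IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

loader = get_makeup_loader(cfg, train=True, transforms=train_transforms)
batch = next(iter(loader))
# batch["image"]: (B, 3, H, W) image tensor
# batch["text"]:  tokenized CLIP prompt ids
# batch["label"]: one-hot domain label (makeup / non-makeup) when MODEL.LABEL_DIM > 0
```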
-------------------------------------------------------------------------------- /dataset/prefetcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DataPrefetcher: 5 | def __init__(self, loader, device): 6 | self.origin_loader = loader 7 | self.loader = iter(loader) 8 | self.device = device # must be assigned before the CUDA stream below is created 9 | self.use_gpu = torch.cuda.is_available() 10 | self.stream = torch.cuda.Stream(device=self.device) if self.use_gpu else None 11 | self.preload() 12 | 13 | def preload(self): 14 | try: 15 | self.batch = next(self.loader) 16 | except StopIteration: 17 | self.loader = iter(self.origin_loader) 18 | self.batch = next(self.loader) 19 | if self.use_gpu: 20 | with torch.cuda.stream(self.stream): 21 | if isinstance(self.batch, dict): 22 | # Loaders in this repo yield dict batches (image / text / label). 23 | self.batch = { 24 | key: value.to(device=self.device, non_blocking=True) 25 | if torch.is_tensor(value) 26 | else value 27 | for key, value in self.batch.items() 28 | } 29 | else: 30 | self.batch = self.batch.to(device=self.device, non_blocking=True) 31 | 32 | def next(self): 33 | if self.use_gpu: 34 | torch.cuda.current_stream().wait_stream(self.stream) 35 | batch = self.batch 36 | self.preload() 37 | return batch 38 | -------------------------------------------------------------------------------- /dataset/text_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import random 5 | 6 | import torch 7 | from PIL import Image 8 | 9 | COMP_TO_TEXT = { 10 | "face": "skin", 11 | "eye": "eyes", 12 | } 13 | 14 | CHECK_TEXT = { 15 | "face": "skin", 16 | "lips": "lip", 17 | "eyes": "eye", 18 | } 19 | 20 | 21 | class MTTextDataset(torch.utils.data.Dataset): 22 | def __init__(self, root, json_file, tokenizer, train=True, transforms=None): 23 | self.data = ( 24 | list(glob.glob(os.path.join(root, "makeup/*.png"))) 25 | + list(glob.glob(os.path.join(root, "makeup/*.jpg"))) 26 | + list(glob.glob(os.path.join(root, "no_makeup/*.png"))) 27 | + list(glob.glob(os.path.join(root, "no_makeup/*.jpg"))) 28 | ) 29 | self.train = train 30 | self.transforms = transforms 31 | self.tokenizer = tokenizer 32 | with open(json_file, "r") as f: 33 | self.desp = json.load(f) 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, idx): 39 | img_name = self.data[idx] 40 | img = Image.open(img_name).convert("RGB") 41 | img = self.transforms(img) 42 | 43 | if os.path.basename(img_name) in self.desp: 44 | all_desps = self.desp[os.path.basename(img_name)] 45 | desp_list = [] 46 | for comp_name, desp in all_desps.items(): 47 | out = random.choice(desp).strip().lower() 48 | if CHECK_TEXT.get(comp_name, comp_name) not in out: 49 | out = f"{out} {COMP_TO_TEXT.get(comp_name, comp_name)}" 50 | desp_list.append(out) 51 | random.shuffle(desp_list) 52 | desp = "makeup with " + ", ".join(desp_list) 53 | else: 54 | desp = "no or light makeup" 55 | 56 | desp = self.tokenizer( 57 | desp, 58 | max_length=self.tokenizer.model_max_length, 59 | padding="max_length", 60 | truncation=True, 61 | return_tensors="pt", 62 | ).input_ids 63 | 64 | return {"pixel_values": img, "input_ids": desp} 65 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import EasyDict, make_cache_dir_path 2 | 3 | __all__ = ["EasyDict", "make_cache_dir_path"] 4 | -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes
spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center 
center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /docs/static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .footer { 12 | padding-bottom: 24px; 13 | } 14 | 15 | .link-block a { 16 | margin-top: 5px; 17 | margin-bottom: 5px; 18 | } 19 | 20 | .dnerf { 21 | font-variant: small-caps; 22 | } 23 | 24 | 25 | .teaser .hero-body { 26 | padding-top: 0; 27 | padding-bottom: 3rem; 28 | } 29 | 30 | .teaser { 31 | font-family: 'Google Sans', sans-serif; 32 | } 33 | 34 | 35 | .publication-title { 36 | } 37 | 38 | .publication-banner { 39 | max-height: parent; 40 | 41 | } 42 | 43 | .publication-banner video { 44 | position: relative; 45 | left: auto; 46 | top: auto; 47 | transform: none; 48 | object-fit: fit; 49 | } 50 | 51 | .publication-header .hero-body { 52 | } 53 | 54 | .publication-title { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-authors { 59 | font-family: 'Google Sans', sans-serif; 60 | } 61 | 62 | .publication-venue { 63 | color: #555; 64 | width: fit-content; 65 | font-weight: bold; 66 | } 67 | 68 | .publication-awards { 69 | color: #ff3860; 70 | width: fit-content; 71 | font-weight: bolder; 72 | } 73 | 74 | .publication-authors { 75 | } 76 | 77 | .publication-authors a { 78 | color: hsl(204, 86%, 53%) !important; 79 | } 80 | 81 | .publication-authors a:hover { 82 | text-decoration: underline; 83 | } 84 | 85 | .author-block { 86 | display: inline-block; 87 | } 88 | 89 | .publication-banner img { 90 | } 91 | 92 | .publication-authors { 93 | /*color: #4286f4;*/ 94 | } 95 | 96 | .publication-video { 97 | position: relative; 98 | width: 100%; 99 | height: 0; 100 | padding-bottom: 56.25%; 101 | 102 | overflow: hidden; 103 | border-radius: 10px !important; 104 | } 105 | 106 | .publication-video iframe { 107 | position: absolute; 108 | top: 0; 109 | left: 0; 110 | width: 100%; 111 | height: 100%; 112 | } 113 | 114 | .publication-body img { 115 | } 116 | 117 | .results-carousel { 118 | overflow: hidden; 119 | } 120 | 121 | .results-carousel .item { 122 | margin: 5px; 123 | overflow: hidden; 124 | border: 1px solid #bbb; 125 | border-radius: 10px; 126 | padding: 0; 127 | font-size: 0; 128 | } 129 | 130 | .results-carousel video { 131 | margin: 0; 132 | } 133 | 134 | 135 | .interpolation-panel { 136 | background: #f5f5f5; 137 | border-radius: 10px; 138 | } 139 | 140 | .interpolation-panel .interpolation-image { 141 | width: 100%; 142 | border-radius: 5px; 143 | } 144 | 145 | .interpolation-video-column { 146 | } 147 | 148 | .interpolation-panel .slider { 149 | margin: 0 !important; 150 | } 151 | 152 | .interpolation-panel .slider { 153 | margin: 0 !important; 154 | } 155 | 156 | #interpolation-image-wrapper { 157 | width: 100%; 158 | } 159 | #interpolation-image-wrapper img { 160 | border-radius: 5px; 161 | } 162 | 163 | td:nth-child(1), th[colspan]:nth-child(2), th[colspan]:nth-child(3), th:nth-child(4), th:nth-child(7) , th:nth-child(10) { 164 | border-right: 1px solid rgb(200, 200, 200) !important; 165 | } 166 | 167 | th[colspan]{ 168 | text-align: center !important; /* Center align text in header cells */ 169 | } 170 | th[rowspan] { 171 | vertical-align: middle !important; /* Center align text vertically in header cells with rowspan */ 172 | } 173 | 174 | th, 
td { 175 | text-align: right !important; /* Center align text in header cells */ 176 | } 177 | 178 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /docs/static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | $(document).ready(function() { 4 | // Check for click events on the navbar burger icon 5 | $(".navbar-burger").click(function() { 6 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 7 | $(".navbar-burger").toggleClass("is-active"); 8 | $(".navbar-menu").toggleClass("is-active"); 9 | 10 | }); 11 | 12 | var options = { 13 | slidesToScroll: 1, 14 | slidesToShow: 3, 15 | loop: true, 16 | infinite: true, 17 | autoplay: false, 18 | autoplaySpeed: 3000, 19 | } 20 | 21 | // Initialize all div with carousel class 22 | var carousels = bulmaCarousel.attach('.carousel', options); 23 | 24 | // Loop on each carousel initialized 25 | for(var i = 0; i < carousels.length; i++) { 26 | // Add listener to event 27 | carousels[i].on('before:show', state => { 28 | console.log(state); 29 | }); 30 | } 31 | 32 | // Access to bulmaCarousel instance of an element 33 | var element = document.querySelector('#my-element'); 34 | if (element && element.bulmaCarousel) { 35 | // bulmaCarousel instance is available as element.bulmaCarousel 36 | element.bulmaCarousel.on('before-show', function(state) { 37 | console.log(state); 38 | }); 39 | } 40 | 41 | /*var player = document.getElementById('interpolation-video'); 42 | player.addEventListener('loadedmetadata', function() { 43 | $('#interpolation-slider').on('input', function(event) { 44 | console.log(this.value, player.duration); 45 | player.currentTime = player.duration / 100 * this.value; 46 | }) 47 | }, false);*/ 48 | preloadInterpolationImages(); 49 | 50 | $('#interpolation-slider').on('input', function(event) { 51 | setInterpolationImage(this.value); 52 | }); 53 | setInterpolationImage(0); 54 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 55 | 56 | bulmaSlider.attach(); 57 | 58 | }) 59 | -------------------------------------------------------------------------------- /docs/static/poster.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/basiclab/MAD/f8b0b012eb3886274df41da9dee0eed6e07e3a34/docs/static/poster.pdf -------------------------------------------------------------------------------- /generate_text_editing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import List, Optional, Tuple 5 | 6 | from joblib import Parallel, delayed 7 | from loguru import logger 8 | from tqdm import tqdm 9 | 10 | from modeling.text_translation import TextTranslationDiffusion 11 | 12 | 13 | def copy_parameters(from_parameters, to_parameters): 14 | to_parameters = list(to_parameters) 15 | assert len(from_parameters) == len(to_parameters) 16 | for s_param, param in zip(from_parameters, to_parameters): 17 | param.data.copy_(s_param.to(param.device).data) 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--save-folder", default="batch_images", type=str) 23 | parser.add_argument("--source-root", required=True, type=str) 24 | parser.add_argument("--source-list", required=True, type=str) 25 | parser.add_argument("--num-process", default=1, type=int) 26 | parser.add_argument("--num-of-step", default=180, type=int) 27 | parser.add_argument("--img-size", default=512, type=int) 28 | parser.add_argument("--model-path", default=None, type=str) 29 | parser.add_argument("--scheduler", default="ddpm", type=str) 30 | parser.add_argument("--sample-steps", default=1000, type=int) 31 | return parser.parse_args() 32 | 33 | 34 | def generate_image( 35 | img_size: int, 36 | save_folder: str, 37 | source_list: List[Tuple[str, str, str]], 38 | offset: int, 39 | device: str, 40 | num_of_step: int, 41 | scheduler: str, 42 | sample_steps: int, 43 | model_path: Optional[str] = None, 44 | ): 45 | diffuser = TextTranslationDiffusion( 46 | img_size=img_size, 47 | scheduler=scheduler, 48 | device=device, 49 | model_path=model_path, 50 | sample_steps=sample_steps, 51 | ) 52 | os.makedirs(save_folder, exist_ok=True) 53 | 54 | progress_bar = tqdm(total=len(source_list), position=int(device.split(":")[-1])) 55 | count_error = 0 56 | 57 | for idx, (source_image, source_mask, editing_prompt) in enumerate(source_list): 58 | save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png") 59 | if os.path.exists(save_image_name): 60 | progress_bar.update(1) 61 | continue 62 | if source_mask.endswith("jpg"): 63 | source_mask = source_mask.replace("jpg", "png") 64 | try: 65 | editing_result = diffuser.modify_with_text( 66 | image=source_image, 67 | mask=source_mask, 68 | prompt=[editing_prompt], 69 | start_from_step=num_of_step, 70 | guidance_scale=15, 71 | ) 72 | except Exception as e: 73 | logger.error(str(e)) 74 | count_error += 1 75 | continue 76 | save_image = editing_result[0] 77 | save_image.save(save_image_name) 78 | progress_bar.update(1) 79 | progress_bar.close() 80 | 81 | if count_error != 0: 82 | print(f"Error in {device}: {count_error}") 83 | 84 | 85 | if __name__ == "__main__": 86 | args = parse_args() 87 | 88 | with open(args.source_list, "r") as f: 89 | data = json.load(f) 90 | source_list = [] 91 | for info in data: # the annotation file is a list of {"image": ..., "style": ...} entries 92 | image_path = os.path.join(args.source_root, info["image"]) 93 | mask_path = image_path.replace("images", "parsing") 94 | source_list.append((image_path, mask_path, info["style"])) 95 | 96 | task_per_process = len(source_list) // args.num_process 97 | Parallel(n_jobs=args.num_process)( 98 | delayed(generate_image)( 99 | args.img_size, 100 |
args.save_folder, 101 | source_list=source_list[ 102 | gpu_idx 103 | * task_per_process : ( 104 | ((gpu_idx + 1) * task_per_process) 105 | if gpu_idx < args.num_process - 1 106 | else len(source_list) 107 | ) 108 | ], 109 | offset=gpu_idx * task_per_process, 110 | device=f"cuda:{gpu_idx}", 111 | num_of_step=args.num_of_step, 112 | model_path=args.model_path, 113 | scheduler=args.scheduler, 114 | sample_steps=args.sample_steps, 115 | ) 116 | for gpu_idx in range(args.num_process) 117 | ) 118 | -------------------------------------------------------------------------------- /generate_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import List, Tuple 4 | 5 | import torch 6 | from joblib import Parallel, delayed 7 | from loguru import logger 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from config import create_cfg, merge_possible_with_base, show_config 12 | from modeling import build_model 13 | from modeling.translation import TranslationDiffusion 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--config", default=None, type=str) 19 | parser.add_argument("--save-folder", default="batch_images", type=str) 20 | parser.add_argument("--source-root", required=True, type=str) 21 | parser.add_argument("--target-root", required=True, type=str) 22 | parser.add_argument("--source-list", required=True, type=str) 23 | parser.add_argument("--target-list", required=True, type=str) 24 | parser.add_argument("--source-label", required=True, type=int) 25 | parser.add_argument("--target-label", required=True, type=int) 26 | parser.add_argument("--num-process", default=1, type=int) 27 | parser.add_argument("--inpainting", action="store_true", default=False) 28 | parser.add_argument("--cam", action="store_true", default=False) 29 | parser.add_argument("--opts", nargs=argparse.REMAINDER, default=None, type=str) 30 | return parser.parse_args() 31 | 32 | 33 | def generate_image( 34 | cfg, 35 | save_folder: str, 36 | source_list: List[Tuple[str, str]], 37 | target_list: List[Tuple[str, str]], 38 | source_label: int, 39 | target_label: int, 40 | offset: int, 41 | device: str, 42 | inpainting: bool, 43 | use_cam: bool, 44 | ): 45 | torch.cuda.set_device(device) 46 | model = build_model(cfg).to(device) 47 | model.eval() 48 | 49 | diffuser = TranslationDiffusion(cfg, device) 50 | os.makedirs(args.save_folder, exist_ok=True) 51 | 52 | count_error = 0 53 | with tqdm(total=len(source_list), position=int(device.split(":")[-1])) as progress_bar: 54 | for idx, ((source_image, source_mask), (target_image, target_mask)) in enumerate( 55 | zip(source_list, target_list) 56 | ): 57 | save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png") 58 | if os.path.exists(save_image_name): 59 | progress_bar.update(1) 60 | continue 61 | if source_mask.endswith("jpg"): 62 | source_mask = source_mask.replace("jpg", "png") 63 | if target_mask.endswith("jpg"): 64 | target_mask = target_mask.replace("jpg", "png") 65 | 66 | try: 67 | transfer_result = diffuser.image_translation( 68 | source_model=model, 69 | target_model=model, 70 | source_image=source_image, 71 | target_image=target_image, 72 | source_class_label=source_label, 73 | target_class_label=target_label, 74 | source_parsing_mask=source_mask, 75 | target_parsing_mask=target_mask, 76 | use_morphing=True, 77 | use_encode_eps=True, 78 | use_cam=use_cam, 79 | inpainting=inpainting, 80 | ) 81 | except Exception as e: 
82 | logger.error(f"Error in {device}: {e}") 83 | count_error += 1 84 | continue 85 | save_image = Image.fromarray( 86 | (transfer_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8") 87 | ) 88 | save_image.save(save_image_name) 89 | progress_bar.update(1) 90 | if count_error != 0: 91 | print(f"Error in {device}: {count_error}") 92 | 93 | 94 | if __name__ == "__main__": 95 | args = parse_args() 96 | cfg = create_cfg() 97 | if args.config: 98 | merge_possible_with_base(cfg, args.config) 99 | if args.opts: 100 | cfg.merge_from_list(args.opts) 101 | show_config(cfg) 102 | 103 | source_list = [] 104 | target_list = [] 105 | 106 | with open(args.source_list, "r") as f_source, open(args.target_list, "r") as f_target: 107 | source_lines = f_source.readlines() 108 | target_lines = f_target.readlines() 109 | for idx in range(len(source_lines)): 110 | if source_lines[idx].strip() == "" or target_lines[idx].strip() == "": 111 | continue 112 | source_path = os.path.join(args.source_root, source_lines[idx].strip()) 113 | source_list.append((source_path, source_path.replace("images", "parsing"))) 114 | target_path = os.path.join(args.target_root, target_lines[idx].strip()) 115 | target_list.append((target_path, target_path.replace("images", "parsing"))) 116 | 117 | assert len(source_list) == len(target_list), "Source and target list should have same length" 118 | task_per_process = len(source_list) // args.num_process 119 | Parallel(n_jobs=args.num_process)( 120 | delayed(generate_image)( 121 | cfg, 122 | args.save_folder, 123 | source_list=source_list[ 124 | gpu_idx 125 | * task_per_process : ( 126 | ((gpu_idx + 1) * task_per_process) 127 | if gpu_idx < args.num_process - 1 128 | else len(source_list) 129 | ) 130 | ], 131 | target_list=target_list[ 132 | gpu_idx 133 | * task_per_process : ( 134 | ((gpu_idx + 1) * task_per_process) 135 | if gpu_idx < args.num_process - 1 136 | else len(target_list) 137 | ) 138 | ], 139 | source_label=args.source_label, 140 | target_label=args.target_label, 141 | offset=gpu_idx * task_per_process, 142 | device=f"cuda:{gpu_idx}", 143 | inpainting=args.inpainting, 144 | use_cam=args.cam, 145 | ) 146 | for gpu_idx in range(args.num_process) 147 | ) 148 | -------------------------------------------------------------------------------- /generate_translation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import List, Tuple 4 | 5 | import torch 6 | from joblib import Parallel, delayed 7 | from loguru import logger 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | from config import create_cfg, merge_possible_with_base, show_config 12 | from modeling import build_model 13 | from modeling.translation import TranslationDiffusion 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--config", default=None, type=str) 19 | parser.add_argument("--save-folder", default="batch_images", type=str) 20 | parser.add_argument("--source-root", required=True, type=str) 21 | parser.add_argument("--source-list", required=True, type=str) 22 | parser.add_argument("--source-label", required=True, type=int) 23 | parser.add_argument("--target-label", required=True, type=int) 24 | parser.add_argument("--num-process", default=1, type=int) 25 | parser.add_argument("--num-of-step", default=700, type=int) 26 | parser.add_argument("--opts", nargs=argparse.REMAINDER, default=None, type=str) 27 | return parser.parse_args() 28 | 29 | 30 | def 
generate_image( 31 | cfg, 32 | save_folder: str, 33 | source_list: List[Tuple[str, str]], 34 | source_label: int, 35 | target_label: int, 36 | offset: int, 37 | device: str, 38 | num_of_step: int, 39 | ): 40 | torch.cuda.set_device(device) 41 | model = build_model(cfg).to(device) 42 | model.eval() 43 | 44 | diffuser = TranslationDiffusion(cfg, device) 45 | os.makedirs(args.save_folder, exist_ok=True) 46 | 47 | count_error = 0 48 | with tqdm( 49 | total=len(source_list), position=int(device.split(":")[-1]) 50 | ) as progress_bar: 51 | for idx, (source_image, source_mask) in enumerate(source_list): 52 | save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png") 53 | if os.path.exists(save_image_name): 54 | progress_bar.update(1) 55 | continue 56 | if source_mask.endswith("jpg"): 57 | source_mask = source_mask.replace("jpg", "png") 58 | try: 59 | transfer_result = diffuser.domain_translation( 60 | source_model=model, 61 | target_model=model, 62 | source_image=source_image, 63 | source_class_label=source_label, 64 | target_class_label=target_label, 65 | parsing_mask=source_mask, 66 | start_from_step=num_of_step, 67 | ) 68 | except Exception as e: 69 | logger.error(str(e)) 70 | count_error += 1 71 | continue 72 | save_image = Image.fromarray( 73 | (transfer_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype( 74 | "uint8" 75 | ) 76 | ) 77 | save_image.save(save_image_name) 78 | 79 | if count_error != 0: 80 | print(f"Error in {device}: {count_error}") 81 | 82 | 83 | if __name__ == "__main__": 84 | args = parse_args() 85 | cfg = create_cfg() 86 | if args.config: 87 | merge_possible_with_base(cfg, args.config) 88 | if args.opts: 89 | cfg.merge_from_list(args.opts) 90 | show_config(cfg) 91 | 92 | source_list = [] 93 | with open(args.source_list, "r") as f: 94 | for line in f.readlines(): 95 | source_line = line.strip() 96 | image_path = os.path.join(args.source_root, source_line) 97 | mask_path = image_path.replace("images", "parsing") 98 | source_list.append((image_path, mask_path)) 99 | 100 | task_per_process = len(source_list) // args.num_process 101 | Parallel(n_jobs=args.num_process)( 102 | delayed(generate_image)( 103 | cfg, 104 | args.save_folder, 105 | source_list=source_list[ 106 | gpu_idx * task_per_process : ( 107 | ((gpu_idx + 1) * task_per_process) 108 | if gpu_idx < args.num_process - 1 109 | else len(source_list) 110 | ) 111 | ], 112 | source_label=args.source_label, 113 | target_label=args.target_label, 114 | offset=gpu_idx * task_per_process, 115 | device=f"cuda:{gpu_idx}", 116 | num_of_step=args.num_of_step, 117 | ) 118 | for gpu_idx in range(args.num_process) 119 | ) 120 | 121 | -------------------------------------------------------------------------------- /misc/__init__.py: -------------------------------------------------------------------------------- 1 | from .meter import AverageMeter, MetricMeter 2 | 3 | __all__ = ["AverageMeter", "MetricMeter"] 4 | -------------------------------------------------------------------------------- /misc/compute_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import tabulate 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data 9 | import torchvision.models as models 10 | from PIL import Image, ImageFile 11 | from torchvision import transforms 12 | from tqdm import tqdm 13 | 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | 16 | ROOT = { 17 | "mt": "data/mtdataset/images", 18 | "mt_removal": 
"generate_outputs/mt_removal", 19 | "beauty": "data/beautyface/images", 20 | "beauty_removal": "generate_outputs/beauty_removal", 21 | "wild": "data/wild/images", 22 | "wild_removal": "generate_outputs/wild_removal", 23 | } 24 | 25 | 26 | class InceptionModel(torch.nn.Module): 27 | def __init__(self, num_class=1000): 28 | super(InceptionModel, self).__init__() 29 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2) 30 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 31 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 32 | self.MaxPool_1 = nn.MaxPool2d(kernel_size=3, stride=2) 33 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 34 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 35 | self.MaxPool_2 = nn.MaxPool2d(kernel_size=3, stride=2) 36 | self.Mixed_5b = InceptionA(192, pool_features=32) 37 | self.Mixed_5c = InceptionA(256, pool_features=64) 38 | self.Mixed_5d = InceptionA(288, pool_features=64) 39 | self.Mixed_6a = InceptionB(288) 40 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 41 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 42 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 43 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 44 | self.Mixed_7a = InceptionD(768) 45 | self.Mixed_7b = InceptionE_1(1280) 46 | self.Mixed_7c = InceptionE_2(2048) 47 | self.AvgPool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 48 | self.fc = nn.Linear(2048, num_class) 49 | 50 | def forward(self, x): 51 | x = interpolate_bilinear_2d_like_tensorflow1x( 52 | x, 53 | size=(299, 299), 54 | align_corners=False, 55 | ) 56 | x = self.Conv2d_1a_3x3(x) 57 | x = self.Conv2d_2a_3x3(x) 58 | x = self.Conv2d_2b_3x3(x) 59 | x = self.MaxPool_1(x) 60 | x = self.Conv2d_3b_1x1(x) 61 | x = self.Conv2d_4a_3x3(x) 62 | x = self.MaxPool_2(x) 63 | x = self.Mixed_5b(x) 64 | x = self.Mixed_5c(x) 65 | x = self.Mixed_5d(x) 66 | x = self.Mixed_6a(x) 67 | x = self.Mixed_6b(x) 68 | x = self.Mixed_6c(x) 69 | x = self.Mixed_6d(x) 70 | x = self.Mixed_6e(x) 71 | x = self.Mixed_7a(x) 72 | x = self.Mixed_7b(x) 73 | x = self.Mixed_7c(x) 74 | x = self.AvgPool(x) 75 | x = torch.flatten(x, 1) 76 | x = self.fc(x) 77 | 78 | return x 79 | 80 | 81 | def get_label(image_root): 82 | all_images = [] 83 | for root in image_root: 84 | all_images.extend( 85 | list(glob.glob(os.path.join(root, "**/*.png"), recursive=True)) 86 | + list(glob.glob(os.path.join(root, "**/*.jpg"), recursive=True)) 87 | ) 88 | label = {name: idx for idx, name in enumerate(sorted(set(all_images)))} 89 | return label 90 | 91 | 92 | def main(args): 93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 94 | label_dict = get_label(["data/mtdataset/images", "data/beautyface/images", "data/wild/images"]) 95 | print(len(label_dict)) 96 | model = InceptionModel(num_class=len(label_dict)).to(device) 97 | model.load_state_dict(torch.load("inception.pth", map_location="cpu")) 98 | model = model.to(device) 99 | model.eval() 100 | 101 | transform_list = transforms.Compose( 102 | [ 103 | transforms.Resize((224, 224), antialias=True), 104 | transforms.CenterCrop((224, 224)), 105 | transforms.ToTensor(), 106 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 107 | ] 108 | ) 109 | 110 | table = [] 111 | target_non_makeup_img = [] 112 | with open(args.non_makeup_file, "r") as f: 113 | for line in f: 114 | if line.strip() != "": 115 | target_non_makeup_img.append(os.path.join(ROOT["mt"], line.strip())) 116 | for target in tqdm(args.path_list): 117 | if not 
os.path.exists(target): 118 | continue 119 | target_file_sorted = sorted( 120 | list(glob.glob(f"{target}/*.png")), 121 | key=lambda x: int(os.path.basename(x).split(".")[0].split("_")[-1]), 122 | ) 123 | selected_non_makeup_image = [] 124 | for file in target_file_sorted: 125 | num = int(os.path.basename(file).split(".")[0].split("_")[-1]) 126 | selected_non_makeup_image.append(target_non_makeup_img[num]) 127 | 128 | score = 0 129 | for img_name, label_name in zip(target_file_sorted, selected_non_makeup_image): 130 | target_label = label_dict[label_name] 131 | 132 | img = Image.open(img_name).convert("RGB") 133 | img = transform_list(img).unsqueeze(0).to(device) 134 | 135 | with torch.no_grad(): 136 | output = model(img) 137 | output = nn.functional.softmax(output, dim=1)[0].cpu().tolist() 138 | score += output[target_label] 139 | score /= len(selected_non_makeup_image) 140 | table.append([target.split("/")[1], f"{score:.3f}"]) 141 | print(tabulate.tabulate(table, headers=["Approach", "Acc"], tablefmt="grid")) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser(description="Compute KID") 146 | parser.add_argument( 147 | "--non-makeup-file", 148 | help="File denoting the order of makeup image", 149 | default="data/nomakeup_test_mt.txt", 150 | ) 151 | parser.add_argument( 152 | "--path-list", 153 | help="List of path for evaluation", 154 | nargs="+", 155 | type=str, 156 | required=True, 157 | ) 158 | parser.add_argument( 159 | "--type", 160 | help="Type of dataset", 161 | choices=["mt", "beauty", "wild"], 162 | default="mt", 163 | ) 164 | args = parser.parse_args() 165 | main(args) 166 | -------------------------------------------------------------------------------- /misc/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | 5 | from tqdm import tqdm 6 | 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) 8 | 9 | import argparse 10 | import glob 11 | 12 | import numpy as np 13 | import tabulate 14 | import torch 15 | import torch.nn 16 | import torchvision.transforms.functional as F 17 | from PIL import Image, ImageFile 18 | from torchvision import transforms 19 | 20 | import torch_fidelity 21 | 22 | ImageFile.LOAD_TRUNCATED_IMAGES = True 23 | 24 | ROOT = { 25 | "mt": "data/mtdataset/images", 26 | "mt_removal": "generate_outputs/mt_removal", 27 | "beauty": "data/beautyface/images", 28 | "beauty_removal": "generate_outputs/beauty_removal", 29 | "wild": "data/wild/images", 30 | "wild_removal": "generate_outputs/wild_removal", 31 | } 32 | 33 | 34 | def set_seed(seed: int = 385832): 35 | random.seed(seed) 36 | np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | 43 | 44 | class TransformPILtoRGBTensor: 45 | def __call__(self, img): 46 | return F.pil_to_tensor(img) 47 | 48 | 49 | class ImagesPathDataset(torch.utils.data.Dataset): 50 | def __init__(self, files, img_size=224): 51 | self.files = files 52 | self.transforms = transforms.Compose( 53 | [ 54 | transforms.Resize((img_size, img_size), antialias=True), 55 | transforms.CenterCrop((img_size, img_size)), 56 | TransformPILtoRGBTensor(), 57 | ] 58 | ) 59 | 60 | def __len__(self): 61 | return len(self.files) 62 | 63 | def __getitem__(self, i): 64 | path = self.files[i] 65 | img = Image.open(path).convert("RGB") 66 | 67 | 
img = self.transforms(img) 68 | return img 69 | 70 | 71 | class CheckImagesPathDataset(torch.utils.data.Dataset): 72 | def __init__( 73 | self, 74 | root, 75 | ori_file=None, 76 | target_files=None, 77 | order_files=None, 78 | img_size=224, 79 | ): 80 | self.files = [] 81 | if ori_file is None: 82 | for file in order_files: 83 | num = int(os.path.basename(file).split(".")[0].split("_")[-1]) 84 | self.files.append(os.path.join(root, f"pred_{num}.png")) 85 | else: 86 | with open(ori_file, "r") as f: 87 | line_of_file = f.readlines() 88 | for file in target_files: 89 | num = int(os.path.basename(file).split(".")[0].split("_")[-1]) 90 | self.files.append(os.path.join(root, line_of_file[num].strip())) 91 | self.transforms = transforms.Compose( 92 | [ 93 | transforms.Resize((img_size, img_size), antialias=True), 94 | transforms.CenterCrop((img_size, img_size)), 95 | TransformPILtoRGBTensor(), 96 | ] 97 | ) 98 | 99 | def __len__(self): 100 | return len(self.files) 101 | 102 | def __getitem__(self, i): 103 | path = self.files[i] 104 | img = Image.open(path).convert("RGB") 105 | img = self.transforms(img) 106 | return img 107 | 108 | 109 | def main(args): 110 | table = [] 111 | target_non_makeup_img = [] 112 | with open(args.non_makeup_file, "r") as f: 113 | for line in f: 114 | target_non_makeup_img.append(os.path.join(ROOT["mt"], line.strip())) 115 | for target in tqdm(args.path_list): 116 | if not os.path.exists(target): 117 | continue 118 | target_file_sorted = sorted( 119 | list(glob.glob(f"{target}/*.png")), 120 | key=lambda x: int(os.path.basename(x).split(".")[0].split("_")[-1]), 121 | ) 122 | selected_non_makeup_image = [] 123 | for file in target_file_sorted: 124 | num = int(os.path.basename(file).split(".")[0].split("_")[-1]) 125 | selected_non_makeup_image.append(target_non_makeup_img[num]) 126 | precision = torch_fidelity.calculate_metrics( 127 | input1=ImagesPathDataset(files=target_file_sorted), 128 | input2=ImagesPathDataset(files=selected_non_makeup_image), 129 | input3=CheckImagesPathDataset( 130 | root=ROOT[args.type], 131 | ori_file=args.makeup_file, 132 | target_files=target_file_sorted, 133 | ), 134 | prc=True, 135 | device="cuda", 136 | verbose=False, 137 | cache=False, 138 | feature_extractor="vgg16", 139 | feature_extractor_weights_path="vgg.pth", 140 | )["precision"] 141 | 142 | recall = torch_fidelity.calculate_metrics( 143 | input1=ImagesPathDataset(files=target_file_sorted), 144 | input2=CheckImagesPathDataset( 145 | root=ROOT[args.type], 146 | ori_file=args.makeup_file, 147 | target_files=target_file_sorted, 148 | ), 149 | input3=CheckImagesPathDataset( # Use mt to remove original non-makeup feature 150 | root=ROOT["mt"], 151 | ori_file=args.non_makeup_file, 152 | target_files=target_file_sorted, 153 | ), 154 | input4=CheckImagesPathDataset( 155 | root=ROOT[ 156 | f"{args.type}_removal" 157 | ], # Use makeup to non-makeup image to remove non-makeup feature 158 | order_files=target_file_sorted, 159 | ), 160 | prc=True, 161 | device="cuda", 162 | cache=False, 163 | verbose=False, 164 | feature_extractor="vgg16", 165 | feature_extractor_weights_path="vgg.pth", 166 | )["recall"] 167 | 168 | kid = torch_fidelity.calculate_metrics( 169 | input1=ImagesPathDataset(files=target_file_sorted), 170 | input2=CheckImagesPathDataset( 171 | root=ROOT[args.type], 172 | ori_file=args.makeup_file, 173 | target_files=target_file_sorted, 174 | ), 175 | input3=CheckImagesPathDataset( # Use mt to remove original non-makeup feature 176 | root=ROOT["mt"], 177 | ori_file=args.non_makeup_file, 
178 |                 target_files=target_file_sorted, 179 |             ), 180 |             input4=CheckImagesPathDataset( 181 |                 root=ROOT[ 182 |                     f"{args.type}_removal" 183 |                 ],  # Use makeup to non-makeup image to remove non-makeup feature 184 |                 order_files=target_file_sorted, 185 |             ), 186 |             kid=True, 187 |             device="cuda", 188 |             cache=False, 189 |             verbose=False, 190 |             feature_extractor="inception-v3-compat", 191 |             feature_extractor_weights_path="inception.pth", 192 |             kid_subset_size=min(1000, len(target_file_sorted)), 193 |         )["kernel_inception_distance_mean"] 194 |  195 |         table.append( 196 |             [ 197 |                 target.split("/")[1], 198 |                 f"{precision:.3f}", 199 |                 f"{recall:.3f}", 200 |                 f"{kid:.3f}", 201 |             ] 202 |         ) 203 |     print( 204 |         tabulate.tabulate( 205 |             table, headers=["Approach", "Precision", "Recall", "KID"], tablefmt="fancy_grid" 206 |         ) 207 |     ) 208 |  209 |  210 | if __name__ == "__main__": 211 |     parser = argparse.ArgumentParser(description="Compute Precision/Recall/KID") 212 |     parser.add_argument( 213 |         "--non-makeup-file", 214 |         help="File denoting the order of non-makeup images", 215 |         default="data/nomakeup_test_mt.txt", 216 |     ) 217 |     parser.add_argument( 218 |         "--makeup-file", 219 |         help="File denoting the order of makeup images", 220 |         default="data/makeup_test_mt.txt", 221 |     ) 222 |     parser.add_argument( 223 |         "--path-list", 224 |         help="List of paths for evaluation", 225 |         nargs="+", 226 |         type=str, 227 |         required=True, 228 |     ) 229 |     parser.add_argument( 230 |         "--type", 231 |         help="Type of dataset", 232 |         choices=["mt", "beauty", "wild"], 233 |         default="mt", 234 |     ) 235 |     args = parser.parse_args() 236 |     set_seed() 237 |     main(args) 238 | -------------------------------------------------------------------------------- /misc/compute_removal_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 |  5 | import cv2 6 | import skimage.metrics 7 | import tabulate 8 | from tqdm import tqdm 9 |  10 | ROOT = { 11 |     "mt": "data/mtdataset/images", 12 | } 13 |  14 |  15 | def main(args): 16 |     table = [] 17 |     for target in tqdm(args.path_list): 18 |         if not os.path.exists(target): 19 |             raise FileNotFoundError(f"{target} not found") 20 |         target_file_sorted = sorted( 21 |             list(glob.glob(f"{target}/pred_*.png")), 22 |             key=lambda x: int(os.path.basename(x).split(".")[0].split("_")[-1]), 23 |         ) 24 |         target_non_makeup_img = [] 25 |         with open(args.origin_file, "r") as f: 26 |             for line in f: 27 |                 if line.strip() != "": 28 |                     target_non_makeup_img.append(os.path.join(ROOT["mt"], line.strip())) 29 |  30 |         total_ssim = 0 31 |         total_psnr = 0 32 |         for file in target_file_sorted: 33 |             idx = int(os.path.basename(file).split(".")[0].rsplit("_", 1)[1]) 34 |             pred = cv2.imread(file) 35 |             non_makeup = cv2.imread(target_non_makeup_img[idx]) 36 |             pred = cv2.resize(pred, (non_makeup.shape[1], non_makeup.shape[0])) 37 |             ssim = skimage.metrics.structural_similarity( 38 |                 cv2.cvtColor(pred, cv2.COLOR_BGR2GRAY), 39 |                 cv2.cvtColor(non_makeup, cv2.COLOR_BGR2GRAY), 40 |             ) 41 |             psnr = skimage.metrics.peak_signal_noise_ratio( 42 |                 cv2.cvtColor(pred, cv2.COLOR_BGR2GRAY), 43 |                 cv2.cvtColor(non_makeup, cv2.COLOR_BGR2GRAY), 44 |             ) 45 |             total_ssim += ssim 46 |             total_psnr += psnr 47 |         table.append( 48 |             [ 49 |                 target, 50 |                 total_ssim / len(target_file_sorted), 51 |                 total_psnr / len(target_file_sorted), 52 |             ] 53 |         ) 54 |     print(tabulate.tabulate(table, headers=["Approach", "SSIM", "PSNR"], tablefmt="fancy_grid")) 55 |  56 |  57 | if __name__ == "__main__": 58 |     parser = argparse.ArgumentParser(description="Compute SSIM/PSNR for makeup removal") 59
| parser.add_argument( 60 | "--origin-file", 61 | help="File denoting the order of makeup image", 62 | default="data/nomakeup_test_mt.txt", 63 | ) 64 | parser.add_argument( 65 | "--path-list", 66 | help="List of path for evaluation", 67 | nargs="+", 68 | type=str, 69 | required=True, 70 | ) 71 | args = parser.parse_args() 72 | main(args) 73 | -------------------------------------------------------------------------------- /misc/compute_text_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import clip 6 | import tabulate 7 | import torch 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--test-file", type=str, required=True) 15 | parser.add_argument("--model-path", type=str, required=True) 16 | parser.add_argument("--orig-root", type=str, required=True) 17 | parser.add_argument("--test-dir", type=str, nargs="+", required=True) 18 | return parser.parse_args() 19 | 20 | 21 | def compute_clip_text(model, preprocess, image_dir, test_meta): 22 | all_images = [] 23 | all_texts = [] 24 | for idx, val in enumerate(tqdm(list(test_meta.values()))): 25 | all_images.append(preprocess(Image.open(os.path.join(image_dir, f"pred_{idx}.png")))) 26 | all_texts.append(f"makeup with {', '.join(val['style'])}") 27 | image = torch.stack(all_images).to(device) 28 | text = clip.tokenize(all_texts).to(device) 29 | 30 | with torch.no_grad(): 31 | image_features = model.encode_image(image) 32 | text_features = model.encode_text(text) 33 | 34 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 35 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 36 | similarity = torch.nn.functional.cosine_similarity(image_features, text_features).mean() 37 | return similarity 38 | 39 | 40 | def compute_image_similarity(model, preprocess, image_dir, orig_dir, test_meta): 41 | all_ori_images = [] 42 | all_result_images = [] 43 | for idx, val in enumerate(tqdm(list(test_meta.values()))): 44 | all_ori_images.append(preprocess(Image.open(os.path.join(orig_dir, val["nonmakeup"])))) 45 | all_result_images.append(preprocess(Image.open(os.path.join(image_dir, f"pred_{idx}.png")))) 46 | 47 | ori_image = torch.stack(all_ori_images).to(device) 48 | result_image = torch.stack(all_result_images).to(device) 49 | 50 | with torch.no_grad(): 51 | ori_image_features = model.encode_image(ori_image) 52 | result_image_features = model.encode_image(result_image) 53 | 54 | ori_image_features = ori_image_features / ori_image_features.norm(dim=1, keepdim=True) 55 | result_image_features = result_image_features / result_image_features.norm( 56 | dim=1, keepdim=True 57 | ) 58 | similarity = torch.nn.functional.cosine_similarity( 59 | ori_image_features, result_image_features 60 | ).mean() 61 | return similarity 62 | 63 | 64 | def compute_style_similarity(model, preprocess, image_dir, orig_dir, test_meta): 65 | all_ori_images = [] 66 | all_result_images = [] 67 | for idx, key in enumerate(tqdm(list(test_meta.keys()))): 68 | all_ori_images.append(preprocess(Image.open(os.path.join(orig_dir, key)))) 69 | all_result_images.append(preprocess(Image.open(os.path.join(image_dir, f"pred_{idx}.png")))) 70 | 71 | ori_image = torch.stack(all_ori_images).to(device) 72 | result_image = torch.stack(all_result_images).to(device) 73 | 74 | with torch.no_grad(): 75 | ori_image_features = model.encode_image(ori_image) 76 | 
result_image_features = model.encode_image(result_image) 77 | 78 | ori_image_features = ori_image_features / ori_image_features.norm(dim=1, keepdim=True) 79 | result_image_features = result_image_features / result_image_features.norm( 80 | dim=1, keepdim=True 81 | ) 82 | similarity = torch.nn.functional.cosine_similarity( 83 | ori_image_features, result_image_features 84 | ).mean() 85 | return similarity 86 | 87 | 88 | if __name__ == "__main__": 89 | device = "cuda" if torch.cuda.is_available() else "cpu" 90 | 91 | args = parse_args() 92 | 93 | model, preprocess = clip.load(args.model_path, device=device) 94 | 95 | with open(args.test_file, "r") as f: 96 | test_meta = json.load(f) 97 | 98 | text_match_score = [] 99 | image_match_score = [] 100 | style_score = [] 101 | for test_dir in args.test_dir: 102 | text_match_score.append(compute_clip_text(model, preprocess, test_dir, test_meta)) 103 | image_match_score.append( 104 | compute_image_similarity(model, preprocess, test_dir, args.orig_root, test_meta) 105 | ) 106 | style_score.append( 107 | compute_style_similarity(model, preprocess, test_dir, args.orig_root, test_meta) 108 | ) 109 | 110 | print( 111 | tabulate.tabulate( 112 | [ 113 | ["CLIP text"] + text_match_score, 114 | ["CLIP image"] + image_match_score, 115 | ["CLIp style"] + style_score, 116 | ], 117 | headers=["Metric"] + [os.path.basename(dir_name) for dir_name in args.test_dir], 118 | tablefmt="pretty", 119 | ) 120 | ) 121 | -------------------------------------------------------------------------------- /misc/constant.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | LEFT_EYE_CONTOUR = [448, 449, 450, 451, 452, 453, 464, 413, 285, 295, 282, 283, 276, 383, 265, 261] 3 | RIGHT_EYE_CONTOUR = [231, 230, 229, 228, 111, 143, 156, 46, 53, 52, 65, 55, 193, 244, 233, 232] 4 | NUM_TO_EYE_CONTOUR = { 5 | 1: RIGHT_EYE_CONTOUR, 6 | 6: LEFT_EYE_CONTOUR, 7 | } 8 | 9 | # Idx for mediapipe facial landmarks 10 | MEDIAPIPE_LIPS_IDX = [0, 267, 269, 270, 13, 14, 17, 402, 146, 405, 409, 415, 291, 37, 39, 40, 178, 308, 181, 310, 311, 312, 185, 314, 317, 318, 61, 191, 321, 324, 78, 80, 81, 82, 84, 87, 88, 91, 95, 375] 11 | MEDIAPIPE_LEFT_EYES_IDX = [384, 385, 386, 387, 388, 390, 263, 362, 398, 466, 373, 374, 249, 380, 381, 382] 12 | MEDIAPIPE_RIGHT_EYES_IDX = [160, 33, 161, 163, 133, 7, 173, 144, 145, 246, 153, 154, 155, 157, 158, 159] 13 | MEDIAPIPE_LEFT_EYEBROW_IDX = [65, 66, 70, 105, 107, 46, 52, 53, 55, 63] 14 | MEDIAPIPE_RIGHT_EYEBROW_IDX = [293, 295, 296, 300, 334, 336, 276, 282, 283, 285] 15 | MEDIAPIPE_NOSE_IDX = [1, 2, 4, 5, 6, 19, 275, 278, 294, 168, 45, 48, 440, 64, 195, 197, 326, 327, 344, 220, 94, 97, 98, 115] 16 | MEDIAPIPE_OVAL = [132, 389, 136, 10, 397, 400, 148, 149, 150, 21, 152, 284, 288, 162, 297, 172, 176, 54, 58, 323, 67, 454, 332, 338, 93, 356, 103, 361, 234, 109, 365, 379, 377, 378, 251, 127] 17 | MEDIAPIPE_LANDMARKS = MEDIAPIPE_LIPS_IDX + MEDIAPIPE_LEFT_EYES_IDX + MEDIAPIPE_RIGHT_EYES_IDX + MEDIAPIPE_LEFT_EYEBROW_IDX + MEDIAPIPE_RIGHT_EYEBROW_IDX + MEDIAPIPE_NOSE_IDX 18 | MEDIAPIPE_REMOVE_LANDMARKS = [135, 136, 169, 150, 170, 149, 140, 176, 171, 148, 175, 152, 396, 377, 369, 400, 395, 378, 394, 379, 364, 365, 367, 397, 435, 288, 361, 401, 323, 366, 454, 447, 356, 264, 368, 389, 251, 284, 54, 21, 162, 139, 127, 34, 234, 227, 93, 137, 132, 177, 215, 58, 138, 172] 19 | 20 | # Idx for dlib facial landmarks 21 | DLIB_LANDMARKS = list(range(0, 17)) 22 | 
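Note: the index lists above are plain lookup tables into MediaPipe FaceMesh's 468-point topology (the dlib indices 0-16 are the jawline of the 68-point model); constant.py itself runs no detector. A minimal illustrative sketch of how such indices can be turned into pixel coordinates — not part of the repository; the helper name and the assumption that the repository root is on PYTHONPATH are mine, and the project's own dataset code may consume these indices differently:

import cv2
import mediapipe as mp

from misc.constant import MEDIAPIPE_LIPS_IDX  # hypothetical usage; assumes repo root on sys.path


def lip_points(image_path: str):
    # Read the image and convert OpenCV's BGR layout to the RGB layout FaceMesh expects.
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    h, w = image.shape[:2]
    with mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1) as mesh:
        result = mesh.process(image)
    if not result.multi_face_landmarks:
        return []  # no face detected
    landmarks = result.multi_face_landmarks[0].landmark
    # FaceMesh returns normalized [0, 1] coordinates; scale to pixels for the lip indices only.
    return [(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in MEDIAPIPE_LIPS_IDX]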
-------------------------------------------------------------------------------- /misc/convert_beauty_face.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import glob 3 | import os 4 | 5 | import click 6 | import numpy as np 7 | from PIL import Image 8 | from tqdm import tqdm 9 | 10 | CONVERT_DICT = { 11 | 12: 9, # up lip 12 | 1: 4, # face 13 | 10: 8, # nose 14 | 14: 10, # neck 15 | 17: 12, # hair 16 | 4: 1, # right eye 17 | 5: 6, # left eye 18 | 2: 2, # right eyebrow 19 | 3: 7, # left eyebrow 20 | 7: 5, # right ear 21 | 8: 3, # left ear 22 | 9: 0, # ear ring 23 | 11: 11, # teeth 24 | 16: 0, # shirt 25 | } 26 | 27 | 28 | @click.command() 29 | @click.option("--original", help="Original json file", type=click.Path(exists=True), required=True) 30 | @click.option("--save_path", help="Original json file", required=True) 31 | def main(original, save_path): 32 | os.makedirs(save_path, exist_ok=True) 33 | original_mask = glob.glob(os.path.join(original, "*.png")) 34 | 35 | for mask_name in tqdm(original_mask): 36 | mask = np.array(Image.open(mask_name)) 37 | new_mask = copy.deepcopy(mask) 38 | for key, value in CONVERT_DICT.items(): 39 | new_mask[mask == key] = value 40 | new_mask = Image.fromarray(new_mask) 41 | new_mask.save(os.path.join(save_path, os.path.basename(mask_name))) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /misc/meter.py: -------------------------------------------------------------------------------- 1 | """Taken from 2 | https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/utils/avgmeter.py 3 | """ 4 | 5 | from collections import defaultdict 6 | 7 | import torch 8 | 9 | 10 | class AverageMeter(object): 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | 27 | class MetricMeter(object): 28 | def __init__(self, delimiter="\t"): 29 | self.meters = defaultdict(AverageMeter) 30 | self.delimiter = delimiter 31 | 32 | def update(self, input_dict): 33 | if input_dict is None: 34 | return 35 | 36 | if not isinstance(input_dict, dict): 37 | raise TypeError("Input to MetricMeter.update() must be a dictionary") 38 | 39 | for k, v in input_dict.items(): 40 | if isinstance(v, torch.Tensor): 41 | v = v.item() 42 | self.meters[k].update(v) 43 | 44 | def __str__(self): 45 | output_str = [] 46 | for name, meter in self.meters.items(): 47 | output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})") 48 | return self.delimiter.join(output_str) 49 | 50 | def get_log_dict(self): 51 | log_dict = {} 52 | for name, meter in self.meters.items(): 53 | log_dict[name] = meter.val 54 | log_dict[f"avg_{name}"] = meter.avg 55 | return log_dict 56 | -------------------------------------------------------------------------------- /misc/morphing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://github.com/andrewdcampbell/face-movie/blob/master/face-movie/face_morph.py 3 | """ 4 | 5 | import copy 6 | from typing import Dict 7 | 8 | import cv2 9 | import numpy as np 10 | from scipy.spatial import Delaunay 11 | 12 | 13 | def warp_im(im, src_landmarks, dst_landmarks, triangulation, flags=cv2.INTER_LINEAR): 14 | im_out = im.copy() 15 | 16 | 
for i in range(len(triangulation)): 17 | src_tri = src_landmarks[triangulation[i]] 18 | dst_tri = dst_landmarks[triangulation[i]] 19 | morph_triangle(im, im_out, src_tri, dst_tri, flags) 20 | 21 | return im_out 22 | 23 | 24 | def morph_triangle(im, im_out, src_tri, dst_tri, flags): 25 | sr = cv2.boundingRect(np.float32([src_tri])) 26 | dr = cv2.boundingRect(np.float32([dst_tri])) 27 | cropped_src_tri = [(src_tri[i][0] - sr[0], src_tri[i][1] - sr[1]) for i in range(3)] 28 | cropped_dst_tri = [(dst_tri[i][0] - dr[0], dst_tri[i][1] - dr[1]) for i in range(3)] 29 | 30 | mask = np.zeros((dr[3], dr[2], 3), dtype=np.float32) 31 | cv2.fillConvexPoly(mask, np.int32(cropped_dst_tri), (1.0, 1.0, 1.0), 16, 0) 32 | 33 | cropped_im = im[sr[1] : sr[1] + sr[3], sr[0] : sr[0] + sr[2]] 34 | 35 | size = (dr[2], dr[3]) 36 | warpImage1 = affine_transform(cropped_im, cropped_src_tri, cropped_dst_tri, size, flags) 37 | 38 | im_out[dr[1] : dr[1] + dr[3], dr[0] : dr[0] + dr[2]] = ( 39 | im_out[dr[1] : dr[1] + dr[3], dr[0] : dr[0] + dr[2]] * (1 - mask) + warpImage1 * mask 40 | ) 41 | 42 | 43 | def affine_transform(src, src_tri, dst_tri, size, flags=cv2.INTER_LINEAR): 44 | M = cv2.getAffineTransform(np.float32(src_tri), np.float32(dst_tri)) 45 | dst = cv2.warpAffine(src, M, size, flags=flags, borderMode=cv2.BORDER_REPLICATE) 46 | return dst 47 | 48 | 49 | def morph_seq( 50 | source_img: np.ndarray, 51 | target_img: np.ndarray, 52 | source_landmarks: np.ndarray, 53 | target_landmarks: np.ndarray, 54 | source_mask: np.ndarray, 55 | target_mask: np.ndarray, 56 | comp_list: Dict[int, float], 57 | ): 58 | source_img = np.float32(source_img) 59 | target_img = np.float32(target_img) 60 | source_mask = np.repeat(np.float32(source_mask[..., None]), 3, axis=-1) 61 | 62 | triangulation = Delaunay(source_landmarks).simplices 63 | warped_source_mask = warp_im( 64 | source_mask, source_landmarks, target_landmarks, triangulation, flags=cv2.INTER_NEAREST 65 | )[..., 0] 66 | warped_source = warp_im(source_img, source_landmarks, target_landmarks, triangulation) 67 | # warped_source[warped_source == 0] = target_img[warped_source == 0] 68 | 69 | un_covered_mask = np.zeros_like(target_mask, dtype="bool") 70 | blended = copy.deepcopy(target_img) 71 | for comp, alpha in comp_list.items(): 72 | if comp == 9 or comp == 13: 73 | target_comp_mask = np.logical_or(target_mask == 9, target_mask == 13) 74 | source_comp_mask = np.logical_or(warped_source_mask == 9, warped_source_mask == 13) 75 | else: 76 | target_comp_mask = target_mask == comp 77 | source_comp_mask = warped_source_mask == comp 78 | 79 | union_comp_mask = target_comp_mask & source_comp_mask 80 | 81 | blended[union_comp_mask] = ( 82 | alpha * warped_source[union_comp_mask] + (1 - alpha) * target_img[union_comp_mask] 83 | ) 84 | target_comp_mask[union_comp_mask] = 0 85 | un_covered_mask[target_comp_mask] = 1 86 | 87 | return blended.astype("uint8"), un_covered_mask 88 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import build_diffuser_model as build_model 2 | 3 | __all__ = ["build_model"] 4 | -------------------------------------------------------------------------------- /modeling/generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import tqdm 4 | from PIL import Image 5 | 6 | 7 | @torch.no_grad() 8 | def generate_image_grid( 9 | net: 
EDMPrecond, 10 | dest_path, 11 | seed=0, 12 | gridw=4, 13 | gridh=4, 14 | device=torch.device("cuda"), 15 | num_steps=256, 16 | sigma_min=0.002, 17 | sigma_max=80, 18 | rho=7, 19 | S_churn=40, 20 | S_min=0.05, 21 | S_max=50, 22 | S_noise=1, 23 | ): 24 | net.eval() 25 | batch_size = gridw * gridh 26 | torch.manual_seed(seed) 27 | 28 | # Pick latents and labels. 29 | print(f"Generating {batch_size} images...") 30 | latents = torch.randn( 31 | [batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device 32 | ) 33 | class_labels = None 34 | if net.label_dim: 35 | class_labels = torch.eye(net.label_dim, device=device)[ 36 | torch.randint(net.label_dim, size=[batch_size], device=device) 37 | ] 38 | 39 | # Adjust noise levels based on what's supported by the network. 40 | sigma_min = max(sigma_min, net.sigma_min) 41 | sigma_max = min(sigma_max, net.sigma_max) 42 | 43 | # Time step discretization. 44 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device) 45 | t_steps = ( 46 | sigma_max ** (1 / rho) 47 | + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) 48 | ) ** rho 49 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 50 | 51 | # Main sampling loop. 52 | x_next = latents.to(torch.float64) * t_steps[0] 53 | for i, (t_cur, t_next) in tqdm.tqdm( 54 | list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit="step" 55 | ): # 0, ..., N-1 56 | x_cur = x_next 57 | 58 | # Increase noise temporarily. 59 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 60 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 61 | x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur) 62 | 63 | # Euler step. 64 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 65 | d_cur = (x_hat - denoised) / t_hat 66 | x_next = x_hat + (t_next - t_hat) * d_cur 67 | 68 | # Apply 2nd order correction. 69 | if i < num_steps - 1: 70 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 71 | d_prime = (x_next - denoised) / t_next 72 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 73 | 74 | # Save image grid. 75 | print(f'Saving image grid to "{dest_path}"...') 76 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8) 77 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2) 78 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels) 79 | image = image.cpu().numpy() 80 | Image.fromarray(image, "RGB").save(dest_path) 81 | print("Done.") 82 | -------------------------------------------------------------------------------- /modeling/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | 13 | from torch_utils import persistence 14 | 15 | # ---------------------------------------------------------------------------- 16 | # Loss function corresponding to the variance preserving (VP) formulation 17 | # from the paper "Score-Based Generative Modeling through Stochastic 18 | # Differential Equations". 19 | 20 | 21 | @persistence.persistent_class 22 | class VPLoss: 23 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 24 | self.beta_d = beta_d 25 | self.beta_min = beta_min 26 | self.epsilon_t = epsilon_t 27 | 28 | def __call__(self, net, images, labels, augment_pipe=None): 29 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 30 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 31 | weight = 1 / sigma**2 32 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 33 | n = torch.randn_like(y) * sigma 34 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 35 | loss = weight * ((D_yn - y) ** 2) 36 | return loss 37 | 38 | def sigma(self, t): 39 | t = torch.as_tensor(t) 40 | return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt() 41 | 42 | 43 | # ---------------------------------------------------------------------------- 44 | # Loss function corresponding to the variance exploding (VE) formulation 45 | # from the paper "Score-Based Generative Modeling through Stochastic 46 | # Differential Equations". 47 | 48 | 49 | @persistence.persistent_class 50 | class VELoss: 51 | def __init__(self, sigma_min=0.02, sigma_max=100): 52 | self.sigma_min = sigma_min 53 | self.sigma_max = sigma_max 54 | 55 | def __call__(self, net, images, labels, augment_pipe=None): 56 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 57 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 58 | weight = 1 / sigma**2 59 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 60 | n = torch.randn_like(y) * sigma 61 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 62 | loss = weight * ((D_yn - y) ** 2) 63 | return loss 64 | 65 | 66 | # ---------------------------------------------------------------------------- 67 | # Improved loss function proposed in the paper "Elucidating the Design Space 68 | # of Diffusion-Based Generative Models" (EDM). 
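# EDMLoss below draws sigma log-normally, sigma = exp(P_mean + P_std * n) with n ~ N(0, 1), so training
# concentrates on mid-range noise levels, and weights each sample by
# (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, which roughly equalizes the loss magnitude across
# noise levels for data with standard deviation sigma_data (0.5 by default).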
69 | 70 | 71 | @persistence.persistent_class 72 | class EDMLoss: 73 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 74 | self.P_mean = P_mean 75 | self.P_std = P_std 76 | self.sigma_data = sigma_data 77 | 78 | def __call__(self, net, images, labels=None, augment_pipe=None): 79 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 80 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 81 | weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 82 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 83 | n = torch.randn_like(y) * sigma 84 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 85 | loss = weight * ((D_yn - y) ** 2) 86 | return loss 87 | 88 | 89 | # ---------------------------------------------------------------------------- 90 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.7.0 2 | diffusers==0.33.1 3 | transformers==4.52.3 4 | tabulate==0.9.0 5 | yacs==0.1.8 6 | scipy==1.11.1 7 | pandas==2.2.0 8 | loguru==0.7.2 9 | colorama==0.4.6 10 | scikit-learn==1.3.0 11 | einops==0.6.0 12 | dlib==19.24.2 13 | mediapipe==0.10.9 14 | scikit-image==0.21.0 15 | aim==3.24.0 16 | gradio==5.31.0 17 | gdown -------------------------------------------------------------------------------- /script/train_text_to_image.sh: -------------------------------------------------------------------------------- 1 | MODEL_NAME="runwayml/stable-diffusion-v1-5" 2 | 3 | accelerate launch --mixed_precision="fp16" sd_training/train_text_to_image.py \ 4 | --pretrained_model_name_or_path=$MODEL_NAME \ 5 | --train_data_dir data/mtdataset/images \ 6 | --train_json_file data/mt_text_anno.json \ 7 | --resolution=512 --center_crop --random_flip \ 8 | --train_batch_size=1 \ 9 | --gradient_accumulation_steps=4 \ 10 | --gradient_checkpointing \ 11 | --max_train_steps=24000 \ 12 | --checkpointing_steps=3000 \ 13 | --learning_rate=1e-05 \ 14 | --max_grad_norm=1 \ 15 | --mixed_precision="fp16" \ 16 | --enable_xformers_memory_efficient_attention \ 17 | --snr_gamma=5.0 \ 18 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 19 | --output_dir="runs/sd_512_512" \ 20 | -------------------------------------------------------------------------------- /torch_fidelity/__init__.py: -------------------------------------------------------------------------------- 1 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase 2 | from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 3 | from torch_fidelity.feature_extractor_clip import FeatureExtractorCLIP 4 | from torch_fidelity.generative_model_base import GenerativeModelBase 5 | from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper 6 | from torch_fidelity.generative_model_onnx import GenerativeModelONNX 7 | from torch_fidelity.metric_fid import KEY_METRIC_FID 8 | from torch_fidelity.metric_isc import KEY_METRIC_ISC_MEAN, KEY_METRIC_ISC_STD 9 | from torch_fidelity.metric_kid import KEY_METRIC_KID_MEAN, KEY_METRIC_KID_STD 10 | from torch_fidelity.metric_ppl import KEY_METRIC_PPL_MEAN, KEY_METRIC_PPL_STD, KEY_METRIC_PPL_RAW 11 | from torch_fidelity.metric_prc import KEY_METRIC_PRECISION, KEY_METRIC_RECALL, KEY_METRIC_F_SCORE 12 | from torch_fidelity.metrics import calculate_metrics 13 | from torch_fidelity.registry import ( 14 | register_dataset, 15 | 
register_feature_extractor, 16 | register_sample_similarity, 17 | register_noise_source, 18 | register_interpolation, 19 | ) 20 | from torch_fidelity.sample_similarity_base import SampleSimilarityBase 21 | from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS 22 | from torch_fidelity.version import __version__ 23 | -------------------------------------------------------------------------------- /torch_fidelity/datasets.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from contextlib import redirect_stdout 3 | 4 | import torch 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets import CIFAR10, STL10, CIFAR100 8 | import torchvision.transforms.functional as F 9 | 10 | from torch_fidelity.helpers import vassert 11 | 12 | 13 | class TransformPILtoRGBTensor: 14 | def __call__(self, img): 15 | vassert(type(img) is Image.Image, "Input is not a PIL.Image") 16 | return F.pil_to_tensor(img) 17 | 18 | 19 | class ImagesPathDataset(Dataset): 20 | def __init__(self, files, transforms=None): 21 | self.files = files 22 | self.transforms = TransformPILtoRGBTensor() if transforms is None else transforms 23 | 24 | def __len__(self): 25 | return len(self.files) 26 | 27 | def __getitem__(self, i): 28 | path = self.files[i] 29 | img = Image.open(path).convert("RGB") 30 | img = self.transforms(img) 31 | return img 32 | 33 | 34 | class Cifar10_RGB(CIFAR10): 35 | def __init__(self, *args, **kwargs): 36 | with redirect_stdout(sys.stderr): 37 | super().__init__(*args, **kwargs) 38 | 39 | def __getitem__(self, index): 40 | img, target = super().__getitem__(index) 41 | return img 42 | 43 | 44 | class Cifar100_RGB(CIFAR100): 45 | def __init__(self, *args, **kwargs): 46 | with redirect_stdout(sys.stderr): 47 | super().__init__(*args, **kwargs) 48 | 49 | def __getitem__(self, index): 50 | img, target = super().__getitem__(index) 51 | return img 52 | 53 | 54 | class STL10_RGB(STL10): 55 | def __init__(self, *args, **kwargs): 56 | with redirect_stdout(sys.stderr): 57 | super().__init__(*args, **kwargs) 58 | 59 | def __getitem__(self, index): 60 | img, target = super().__getitem__(index) 61 | return img 62 | 63 | 64 | class RandomlyGeneratedDataset(Dataset): 65 | def __init__(self, num_samples, *dimensions, dtype=torch.uint8, seed=2021): 66 | vassert(dtype == torch.uint8, "Unsupported dtype") 67 | rng_stash = torch.get_rng_state() 68 | try: 69 | torch.manual_seed(seed) 70 | self.imgs = torch.randint(0, 255, (num_samples, *dimensions), dtype=dtype) 71 | finally: 72 | torch.set_rng_state(rng_stash) 73 | 74 | def __len__(self): 75 | return self.imgs.shape[0] 76 | 77 | def __getitem__(self, i): 78 | return self.imgs[i] 79 | -------------------------------------------------------------------------------- /torch_fidelity/defaults.py: -------------------------------------------------------------------------------- 1 | DEFAULTS = { 2 | "input1": None, 3 | "input2": None, 4 | "input3": None, 5 | "input4": None, 6 | "cuda": True, 7 | "batch_size": 64, 8 | "isc": False, 9 | "fid": False, 10 | "kid": False, 11 | "prc": False, 12 | "ppl": False, 13 | "feature_extractor": None, 14 | "feature_layer_isc": None, 15 | "feature_layer_fid": None, 16 | "feature_layer_kid": None, 17 | "feature_layer_prc": None, 18 | "feature_extractor_weights_path": None, 19 | "feature_extractor_internal_dtype": None, 20 | "feature_extractor_compile": False, 21 | "isc_splits": 10, 22 | "kid_subsets": 100, 23 | "kid_subset_size": 1000, 24 | 
"kid_kernel": "poly", 25 | "kid_kernel_poly_degree": 3, 26 | "kid_kernel_poly_gamma": None, 27 | "kid_kernel_poly_coef0": 1, 28 | "kid_kernel_rbf_sigma": 10, 29 | "ppl_epsilon": 1e-4, 30 | "ppl_reduction": "mean", 31 | "ppl_sample_similarity": "lpips-vgg16", 32 | "ppl_sample_similarity_resize": 64, 33 | "ppl_sample_similarity_dtype": "uint8", 34 | "ppl_discard_percentile_lower": 1, 35 | "ppl_discard_percentile_higher": 99, 36 | "ppl_z_interp_mode": "lerp", 37 | "prc_neighborhood": 3, 38 | "prc_batch_size": 10000, 39 | "samples_shuffle": True, 40 | "samples_find_deep": False, 41 | "samples_find_ext": "png,jpg,jpeg", 42 | "samples_ext_lossy": "jpg,jpeg", 43 | "samples_resize_and_crop": 0, 44 | "datasets_root": None, 45 | "datasets_download": True, 46 | "cache_root": None, 47 | "cache": True, 48 | "input1_cache_name": None, 49 | "input1_model_z_type": "normal", 50 | "input1_model_z_size": None, 51 | "input1_model_num_classes": 0, 52 | "input1_model_num_samples": None, 53 | "input2_cache_name": None, 54 | "input2_model_z_type": "normal", 55 | "input2_model_z_size": None, 56 | "input2_model_num_classes": 0, 57 | "input2_model_num_samples": None, 58 | "input3_cache_name": None, 59 | "input3_model_z_type": "normal", 60 | "input3_model_z_size": None, 61 | "input3_model_num_classes": 0, 62 | "input3_model_num_samples": None, 63 | "input4_cache_name": None, 64 | "input4_model_z_type": "normal", 65 | "input4_model_z_size": None, 66 | "input4_model_num_classes": 0, 67 | "input4_model_num_samples": None, 68 | "rng_seed": 2020, 69 | "save_cpu_ram": False, 70 | "verbose": True, 71 | } 72 | -------------------------------------------------------------------------------- /torch_fidelity/deprecations.py: -------------------------------------------------------------------------------- 1 | DEPRECATIONS = { 2 | "kid_degree": { 3 | "new_name": "kid_kernel_poly_degree", 4 | "since": "0.4.0", 5 | "reason": "Supporting various kernel functions, including RBF, to enable computing the metric proposed in https://arxiv.org/pdf/2401.09603.pdf", 6 | }, 7 | "kid_gamma": { 8 | "new_name": "kid_kernel_poly_gamma", 9 | "since": "0.4.0", 10 | "reason": "Supporting various kernel functions, including RBF, to enable computing the metric proposed in https://arxiv.org/pdf/2401.09603.pdf", 11 | }, 12 | "kid_coef0": { 13 | "new_name": "kid_kernel_poly_coef0", 14 | "since": "0.4.0", 15 | "reason": "Supporting various kernel functions, including RBF, to enable computing the metric proposed in https://arxiv.org/pdf/2401.09603.pdf", 16 | }, 17 | } 18 | -------------------------------------------------------------------------------- /torch_fidelity/feature_extractor_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch_fidelity.helpers import vassert 4 | 5 | 6 | class FeatureExtractorBase(nn.Module): 7 | def __init__(self, name, features_list): 8 | """ 9 | Base class for feature extractors that can be used in :func:`calculate_metrics`. 10 | 11 | Args: 12 | 13 | name (str): Unique name of the subclassed feature extractor, must be the same as used in 14 | :func:`register_feature_extractor`. 15 | 16 | features_list (list): List of feature names, provided by the subclassed feature extractor. 
17 | """ 18 | super(FeatureExtractorBase, self).__init__() 19 | vassert(type(name) is str, "Feature extractor name must be a string") 20 | vassert(type(features_list) in (list, tuple), "Wrong features list type") 21 | vassert( 22 | all((a in self.get_provided_features_list() for a in features_list)), 23 | f"Requested features {tuple(features_list)} are not on the list provided by the selected feature extractor " 24 | f"{self.get_provided_features_list()}", 25 | ) 26 | vassert(len(features_list) == len(set(features_list)), "Duplicate features requested") 27 | vassert(len(features_list) > 0, "No features requested") 28 | self.name = name 29 | self.features_list = features_list 30 | 31 | def get_name(self): 32 | return self.name 33 | 34 | @staticmethod 35 | def get_provided_features_list(): 36 | """ 37 | Returns a tuple of feature names, extracted by the subclassed feature extractor. 38 | """ 39 | raise NotImplementedError 40 | 41 | @staticmethod 42 | def get_default_feature_layer_for_metric(metric): 43 | """ 44 | Returns a default feature name to be used for the metric computation. 45 | """ 46 | raise NotImplementedError 47 | 48 | @staticmethod 49 | def can_be_compiled(): 50 | """ 51 | Indicates whether the subclass can be safely wrapped with torch.compile. 52 | """ 53 | raise NotImplementedError 54 | 55 | @staticmethod 56 | def get_dummy_input_for_compile(): 57 | """ 58 | Returns a dummy input for compilation 59 | """ 60 | raise NotImplementedError 61 | 62 | def get_requested_features_list(self): 63 | return self.features_list 64 | 65 | def convert_features_tuple_to_dict(self, features): 66 | # The only compound return type of the forward function amenable to JIT tracing is tuple. 67 | # This function simply helps to recover the mapping. 68 | vassert( 69 | type(features) is tuple and len(features) == len(self.features_list), 70 | "Features must be the output of forward function", 71 | ) 72 | return dict(((name, feature) for name, feature in zip(self.features_list, features))) 73 | 74 | def forward(self, input): 75 | """ 76 | Returns a tuple of tensors extracted from the `input`, in the same order as they are provided by 77 | `get_provided_features_list()`. 78 | """ 79 | raise NotImplementedError 80 | -------------------------------------------------------------------------------- /torch_fidelity/feature_extractor_dinov2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | 4 | import torch 5 | import torchvision 6 | 7 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase 8 | from torch_fidelity.helpers import vassert, text_to_dtype, CleanStderr 9 | 10 | from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x 11 | 12 | 13 | MODEL_METADATA = { 14 | "dinov2-vit-s-14": "dinov2_vits14", # dim=384 15 | "dinov2-vit-b-14": "dinov2_vitb14", # dim=768 16 | "dinov2-vit-l-14": "dinov2_vitl14", # dim=1024 17 | "dinov2-vit-g-14": "dinov2_vitg14", # dim=1536 18 | } 19 | 20 | 21 | class FeatureExtractorDinoV2(FeatureExtractorBase): 22 | INPUT_IMAGE_SIZE = 224 23 | 24 | def __init__( 25 | self, 26 | name, 27 | features_list, 28 | feature_extractor_weights_path=None, 29 | feature_extractor_internal_dtype=None, 30 | **kwargs, 31 | ): 32 | """ 33 | DinoV2 feature extractor for 2D RGB 24bit images. 34 | 35 | Args: 36 | 37 | name (str): Unique name of the feature extractor, must be the same as used in 38 | :func:`register_feature_extractor`. 
39 | 40 | features_list (list): A list of the requested feature names, which will be produced for each input. This 41 | feature extractor provides the following features: 42 | 43 | - 'dinov2' 44 | 45 | feature_extractor_weights_path (str): Path to the pretrained InceptionV3 model weights in PyTorch format. 46 | Refer to `util_convert_inception_weights` for making your own. Downloads from internet if `None`. 47 | 48 | feature_extractor_internal_dtype (str): dtype to use inside the feature extractor. Specifying it may improve 49 | numerical precision in some cases. Supported values are 'float32' (default), and 'float64'. 50 | """ 51 | super(FeatureExtractorDinoV2, self).__init__(name, features_list) 52 | vassert( 53 | feature_extractor_internal_dtype in ("float32", "float64", None), 54 | "Only 32-bit floats are supported for internal dtype of this feature extractor", 55 | ) 56 | 57 | vassert(name in MODEL_METADATA, f"Model {name} not found; available models = {list(MODEL_METADATA.keys())}") 58 | self.feature_extractor_internal_dtype = text_to_dtype(feature_extractor_internal_dtype, "float32") 59 | 60 | with CleanStderr(["xFormers not available", "Using cache found in"], sys.stderr), warnings.catch_warnings(): 61 | warnings.filterwarnings("ignore", message="xFormers is not available") 62 | if feature_extractor_weights_path is None: 63 | self.model = torch.hub.load("facebookresearch/dinov2", MODEL_METADATA[name]) 64 | else: 65 | raise NotImplementedError 66 | 67 | self.to(self.feature_extractor_internal_dtype) 68 | self.requires_grad_(False) 69 | self.eval() 70 | 71 | def forward(self, x): 72 | vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8") 73 | vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}") 74 | 75 | x = x.to(self.feature_extractor_internal_dtype) 76 | # N x 3 x ? x ? 
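        # What follows: resize to 224x224 with the TensorFlow-1.x-compatible bilinear kernel,
        # normalize with ImageNet mean/std scaled by 255 (the input is uint8 in [0, 255]),
        # and run the DINOv2 backbone to obtain one global embedding per image.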
77 | 78 | x = interpolate_bilinear_2d_like_tensorflow1x( 79 | x, 80 | size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE), 81 | align_corners=False, 82 | ) 83 | # N x 3 x 224 x 224 84 | 85 | x = torchvision.transforms.functional.normalize( 86 | x, 87 | (255 * 0.485, 255 * 0.456, 255 * 0.406), 88 | (255 * 0.229, 255 * 0.224, 255 * 0.225), 89 | inplace=False, 90 | ) 91 | # N x 3 x 224 x 224 92 | 93 | x = self.model(x) 94 | 95 | out = { 96 | "dinov2": x.to(torch.float32), 97 | } 98 | 99 | return tuple(out[a] for a in self.features_list) 100 | 101 | @staticmethod 102 | def get_provided_features_list(): 103 | return ("dinov2",) 104 | 105 | @staticmethod 106 | def get_default_feature_layer_for_metric(metric): 107 | return { 108 | "isc": "dinov2", 109 | "fid": "dinov2", 110 | "kid": "dinov2", 111 | "prc": "dinov2", 112 | }[metric] 113 | 114 | @staticmethod 115 | def can_be_compiled(): 116 | return True 117 | 118 | @staticmethod 119 | def get_dummy_input_for_compile(): 120 | return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8) 121 | -------------------------------------------------------------------------------- /torch_fidelity/feature_extractor_vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torchvision 4 | 5 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase 6 | from torch_fidelity.helpers import text_to_dtype, vassert 7 | from torch_fidelity.interpolate_compat_tensorflow import ( 8 | interpolate_bilinear_2d_like_tensorflow1x, 9 | ) 10 | from torch_fidelity.utils_torchvision import torchvision_load_pretrained_vgg16 11 | 12 | 13 | class FeatureExtractorVGG16(FeatureExtractorBase): 14 | INPUT_IMAGE_SIZE = 224 15 | 16 | def __init__( 17 | self, 18 | name, 19 | features_list, 20 | feature_extractor_weights_path=None, 21 | feature_extractor_internal_dtype=None, 22 | **kwargs, 23 | ): 24 | """ 25 | VGG16 feature extractor for 2D RGB 24bit images. 26 | 27 | Args: 28 | 29 | name (str): Unique name of the feature extractor, must be the same as used in 30 | :func:`register_feature_extractor`. 31 | 32 | features_list (list): A list of the requested feature names, which will be produced for each input. This 33 | feature extractor provides the following features: 34 | 35 | - 'fc2' 36 | - 'fc2_relu' 37 | 38 | feature_extractor_weights_path (str): Path to the pretrained InceptionV3 model weights in PyTorch format. 39 | Refer to `util_convert_inception_weights` for making your own. Downloads from internet if `None`. 40 | 41 | feature_extractor_internal_dtype (str): dtype to use inside the feature extractor. Specifying it may improve 42 | numerical precision in some cases. Supported values are 'float32' (default), and 'float64'. 
43 | """ 44 | super(FeatureExtractorVGG16, self).__init__(name, features_list) 45 | vassert( 46 | feature_extractor_internal_dtype in ("float32", "float64", None), 47 | "Only 32-bit floats are supported for internal dtype of this feature extractor", 48 | ) 49 | self.feature_extractor_internal_dtype = text_to_dtype( 50 | feature_extractor_internal_dtype, "float32" 51 | ) 52 | 53 | if feature_extractor_weights_path is None: 54 | self.model = torchvision_load_pretrained_vgg16(**kwargs) 55 | else: 56 | state_dict = torch.load(feature_extractor_weights_path) 57 | self.model = torchvision.models.vgg16() 58 | new_state_dict = {} 59 | for k, v in state_dict.items(): 60 | new_state_dict[k.replace("model.", "")] = v 61 | print(self.model.load_state_dict(new_state_dict, strict=False)) 62 | for cls_tail_id in (6, 5, 4): 63 | del self.model.classifier[cls_tail_id] 64 | 65 | self.to(self.feature_extractor_internal_dtype) 66 | self.requires_grad_(False) 67 | self.eval() 68 | 69 | def forward(self, x): 70 | vassert( 71 | torch.is_tensor(x) and x.dtype == torch.uint8, 72 | "Expecting image as torch.Tensor with dtype=torch.uint8", 73 | ) 74 | vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}") 75 | features = {} 76 | remaining_features = self.features_list.copy() 77 | 78 | x = x.to(self.feature_extractor_internal_dtype) 79 | # N x 3 x ? x ? 80 | 81 | x = interpolate_bilinear_2d_like_tensorflow1x( 82 | x, 83 | size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE), 84 | align_corners=False, 85 | ) 86 | # N x 3 x 224 x 224 87 | 88 | x = torchvision.transforms.functional.normalize( 89 | x, 90 | (255 * 0.485, 255 * 0.456, 255 * 0.406), 91 | (255 * 0.229, 255 * 0.224, 255 * 0.225), 92 | inplace=False, 93 | ) 94 | # N x 3 x 224 x 224 95 | 96 | x = self.model(x) 97 | 98 | if "fc2" in remaining_features: 99 | features["fc2"] = x.to(torch.float32) 100 | remaining_features.remove("fc2") 101 | if len(remaining_features) == 0: 102 | return tuple(features[a] for a in self.features_list) 103 | 104 | features["fc2_relu"] = F.relu(x).to(torch.float32) 105 | 106 | return tuple(features[a] for a in self.features_list) 107 | 108 | @staticmethod 109 | def get_provided_features_list(): 110 | return "fc2", "fc2_relu" 111 | 112 | @staticmethod 113 | def get_default_feature_layer_for_metric(metric): 114 | return { 115 | "isc": "fc2_relu", 116 | "fid": "fc2_relu", 117 | "kid": "fc2_relu", 118 | "prc": "fc2_relu", 119 | }[metric] 120 | 121 | @staticmethod 122 | def can_be_compiled(): 123 | return True 124 | 125 | @staticmethod 126 | def get_dummy_input_for_compile(): 127 | return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8) 128 | -------------------------------------------------------------------------------- /torch_fidelity/generative_model_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class GenerativeModelBase(ABC, torch.nn.Module): 7 | """ 8 | Base class for generative models that can be used as inputs in :func:`calculate_metrics`. 9 | """ 10 | 11 | @property 12 | @abstractmethod 13 | def z_size(self): 14 | """ 15 | Size of the noise dimension of the generative model (positive integer). 16 | """ 17 | pass 18 | 19 | @property 20 | @abstractmethod 21 | def z_type(self): 22 | """ 23 | Type of the noise used by the generative model (see :ref:`registry ` for a list of preregistered noise 24 | types, see :func:`register_noise_source` for registering a new noise type). 
25 | """ 26 | pass 27 | 28 | @property 29 | @abstractmethod 30 | def num_classes(self): 31 | """ 32 | Number of classes used by a conditional generative model. Must return zero for unconditional models. 33 | """ 34 | pass 35 | -------------------------------------------------------------------------------- /torch_fidelity/generative_model_modulewrapper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | 5 | from torch_fidelity.generative_model_base import GenerativeModelBase 6 | from torch_fidelity.helpers import vassert 7 | 8 | 9 | class GenerativeModelModuleWrapper(GenerativeModelBase): 10 | def __init__(self, module, z_size, z_type, num_classes, make_copy=False, make_eval=True, cuda=None): 11 | """ 12 | Wraps any generative model :class:`torch.nn.Module`, implements the :class:`GenerativeModelBase` interface, and 13 | provides a few convenience functions. 14 | 15 | Args: 16 | 17 | module (torch.nn.Module): A generative model module, taking a batch of noise samples, and producing 18 | generative samples. 19 | 20 | z_size (int): Size of the noise dimension of the generative model (positive integer). 21 | 22 | z_type (str): Type of the noise used by the generative model (see :ref:`registry ` for a list of 23 | preregistered noise types, see :func:`register_noise_source` for registering a new noise type). 24 | 25 | num_classes (int): Number of classes used by a conditional generative model. Must return zero for 26 | unconditional models. 27 | 28 | make_copy (bool): Makes a copy of the model weights if `True`. Default: `False`. 29 | 30 | make_eval (bool): Switches to :class:`torch.nn.Module` evaluation mode upon construction if `True`. Default: 31 | `True`. 32 | 33 | cuda (bool): Moves the module on a CUDA device if `True`, moves to CPU if `False`, does nothing if `None`. 34 | Default: `None`. 
35 | """ 36 | super().__init__() 37 | vassert(isinstance(module, torch.nn.Module), "Not an instance of torch.nn.Module") 38 | vassert(type(z_size) is int and z_size > 0, "z_size must be a positive integer") 39 | vassert(z_type in ("normal", "unit", "uniform_0_1"), f"z_type={z_type} not implemented") 40 | vassert(type(num_classes) is int and num_classes >= 0, "num_classes must be a non-negative integer") 41 | self.module = module 42 | if make_copy: 43 | self.module = copy.deepcopy(self.module) 44 | if make_eval: 45 | self.module.eval() 46 | if cuda is not None: 47 | if cuda: 48 | self.module = self.module.cuda() 49 | else: 50 | self.module = self.module.cpu() 51 | self._z_size = z_size 52 | self._z_type = z_type 53 | self._num_classes = num_classes 54 | 55 | @property 56 | def z_size(self): 57 | return self._z_size 58 | 59 | @property 60 | def z_type(self): 61 | return self._z_type 62 | 63 | @property 64 | def num_classes(self): 65 | return self._num_classes 66 | 67 | def forward(self, *args, **kwargs): 68 | return self.module.forward(*args, **kwargs) 69 | -------------------------------------------------------------------------------- /torch_fidelity/generative_model_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from torch_fidelity.generative_model_base import GenerativeModelBase 7 | from torch_fidelity.helpers import vassert 8 | 9 | 10 | class GenerativeModelONNX(GenerativeModelBase): 11 | def __init__(self, path_onnx, z_size, z_type, num_classes): 12 | """ 13 | Wraps :obj:`ONNX` generative model, implements the :class:`GenerativeModelBase` interface. 14 | 15 | Args: 16 | 17 | path_onnx (str): Path to a generative model in :obj:`ONNX` format. 18 | 19 | z_size (int): Size of the noise dimension of the generative model (positive integer). 20 | 21 | z_type (str): Type of the noise used by the generative model (see :ref:`registry ` for a list of 22 | preregistered noise types, see :func:`register_noise_source` for registering a new noise type). 23 | 24 | num_classes (int): Number of classes used by a conditional generative model. Must return zero for 25 | unconditional models. 26 | """ 27 | super().__init__() 28 | vassert(os.path.isfile(path_onnx), f'Model file not found at "{path_onnx}"') 29 | vassert(type(z_size) is int and z_size > 0, "z_size must be a positive integer") 30 | vassert(z_type in ("normal", "unit", "uniform_0_1"), f"z_type={z_type} not implemented") 31 | vassert(type(num_classes) is int and num_classes >= 0, "num_classes must be a non-negative integer") 32 | try: 33 | import onnxruntime 34 | except ImportError as e: 35 | # This message may be removed if onnxruntime becomes a unified package with embedded CUDA dependencies, 36 | # like for example pytorch 37 | print( 38 | "====================================================================================================\n" 39 | "Loading ONNX models in PyTorch requires ONNX runtime package, which we did not want to include in\n" 40 | "torch_fidelity package requirements.txt. The two relevant pip packages are:\n" 41 | " - onnxruntime (pip install onnxruntime), or\n" 42 | " - onnxruntime-gpu (pip install onnxruntime-gpu).\n" 43 | 'If you choose to install "onnxruntime", you will be able to run inference on CPU only - this may be\n' 44 | 'slow. 
With "onnxruntime-gpu" speed is not an issue, but at run time you might face CUDA toolkit\n' 45 | "versions incompatibility, which can only be resolved by recompiling onnxruntime-gpu from source.\n" 46 | "Alternatively, use calculate_metrics API and pass an instance of GenerativeModelBase as an input.\n" 47 | "====================================================================================================" 48 | ) 49 | raise e 50 | self.ort_session = onnxruntime.InferenceSession(path_onnx) 51 | self.input_names = [a.name for a in self.ort_session.get_inputs()] 52 | self._z_size = z_size 53 | self._z_type = z_type 54 | self._num_classes = num_classes 55 | 56 | @property 57 | def z_size(self): 58 | return self._z_size 59 | 60 | @property 61 | def z_type(self): 62 | return self._z_type 63 | 64 | @property 65 | def num_classes(self): 66 | return self._num_classes 67 | 68 | @staticmethod 69 | def to_numpy(tensor): 70 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() 71 | 72 | def forward(self, *args): 73 | vassert( 74 | len(args) == len(self.input_names), 75 | f"Number of input arguments {len(args)} does not match ONNX model: {self.input_names}", 76 | ) 77 | vassert(all(torch.is_tensor(a) for a in args), "All model inputs must be tensors") 78 | ort_input = {self.input_names[i]: self.to_numpy(args[i]) for i in range(len(args))} 79 | ort_output = self.ort_session.run(None, ort_input) 80 | ort_output = ort_output[0] 81 | vassert(isinstance(ort_output, np.ndarray), "Invalid output of ONNX model") 82 | out = torch.from_numpy(ort_output).to(device=args[0].device) 83 | return out 84 | -------------------------------------------------------------------------------- /torch_fidelity/helpers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import warnings 4 | 5 | import torch 6 | 7 | from torch_fidelity.defaults import DEFAULTS 8 | from torch_fidelity.deprecations import DEPRECATIONS 9 | 10 | 11 | def vassert(truecond, message): 12 | if not truecond: 13 | raise ValueError(message) 14 | 15 | 16 | def vprint(verbose, message): 17 | if verbose: 18 | print(message, file=sys.stderr) 19 | 20 | 21 | def get_kwarg(name, kwargs): 22 | return kwargs.get(name, DEFAULTS[name]) 23 | 24 | 25 | def json_decode_string(s): 26 | try: 27 | out = json.loads(s) 28 | except json.JSONDecodeError as e: 29 | print(f"Failed to decode JSON string: {s}", file=sys.stderr) 30 | raise 31 | return out 32 | 33 | 34 | def text_to_dtype(name, default=None): 35 | DTYPES = { 36 | "uint8": torch.uint8, 37 | "float32": torch.float32, 38 | "float64": torch.float32, 39 | } 40 | if default in DTYPES: 41 | default = DTYPES[default] 42 | return DTYPES.get(name, default) 43 | 44 | 45 | class CleanStderr: 46 | def __init__(self, filter_phrases, stream=sys.stderr): 47 | self.filter_phrases = filter_phrases 48 | self.stream = stream 49 | 50 | def __enter__(self): 51 | sys.stderr = self 52 | 53 | def __exit__(self, exc_type, exc_value, traceback): 54 | sys.stderr = self.stream 55 | 56 | def write(self, msg): 57 | if not any(phrase in msg for phrase in self.filter_phrases): 58 | self.stream.write(msg) 59 | 60 | def flush(self): 61 | self.stream.flush() 62 | 63 | 64 | def process_deprecations(cfg): 65 | for k, v in cfg.items(): 66 | if k not in DEPRECATIONS: 67 | continue 68 | depr = DEPRECATIONS[k] 69 | new_k = depr["new_name"] 70 | cfg[new_k] = v 71 | warnings.warn( 72 | f"Argument \"{k}\" is deprecated since {depr['since']}; use \"{new_k}\" 
instead. Reason: {depr['reason']}", 73 | FutureWarning, 74 | stacklevel=2, 75 | ) 76 | del cfg[k] 77 | -------------------------------------------------------------------------------- /torch_fidelity/interpolate_compat_tensorflow.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.utils import _ntuple 6 | 7 | 8 | def interpolate_bilinear_2d_like_tensorflow1x(input, size=None, scale_factor=None, align_corners=None, method="slow"): 9 | r"""Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor` 10 | 11 | Epsilon-exact bilinear interpolation as it is implemented in TensorFlow 1.x: 12 | https://github.com/tensorflow/tensorflow/blob/f66daa493e7383052b2b44def2933f61faf196e0/tensorflow/core/kernels/image_resizer_state.h#L41 13 | https://github.com/tensorflow/tensorflow/blob/6795a8c3a3678fb805b6a8ba806af77ddfe61628/tensorflow/core/kernels/resize_bilinear_op.cc#L85 14 | as per proposal: 15 | https://github.com/pytorch/pytorch/issues/10604#issuecomment-465783319 16 | 17 | Related materials: 18 | https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35 19 | https://jricheimer.github.io/tensorflow/2019/02/11/resize-confusion/ 20 | https://machinethink.net/blog/coreml-upsampling/ 21 | 22 | Currently only 2D spatial sampling is supported, i.e. expected inputs are 4-D in shape. 23 | 24 | The input dimensions are interpreted in the form: 25 | `mini-batch x channels x height x width`. 26 | 27 | Args: 28 | input (Tensor): the input tensor 29 | size (Tuple[int, int]): output spatial size. 30 | scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. 31 | align_corners (bool, optional): Same meaning as in TensorFlow 1.x. 32 | method (str, optional): 33 | 'slow' (1e-4 L_inf error on GPU, bit-exact on CPU, with checkerboard 32x32->299x299), or 34 | 'fast' (1e-3 L_inf error on GPU and CPU, with checkerboard 32x32->299x299) 35 | """ 36 | if method not in ("slow", "fast"): 37 | raise ValueError('how_exact can only be one of "slow", "fast"') 38 | 39 | if input.dim() != 4: 40 | raise ValueError("input must be a 4-D tensor") 41 | 42 | if not torch.is_floating_point(input): 43 | raise ValueError("input must be of floating point dtype") 44 | 45 | if size is not None and (type(size) not in (tuple, list) or len(size) != 2): 46 | raise ValueError("size must be a list or a tuple of two elements") 47 | 48 | if align_corners is None: 49 | raise ValueError("align_corners is not specified (use this function for a complete determinism)") 50 | 51 | def _check_size_scale_factor(dim): 52 | if size is None and scale_factor is None: 53 | raise ValueError("either size or scale_factor should be defined") 54 | if size is not None and scale_factor is not None: 55 | raise ValueError("only one of size or scale_factor should be defined") 56 | if scale_factor is not None and isinstance(scale_factor, tuple) and len(scale_factor) != dim: 57 | raise ValueError( 58 | "scale_factor shape must match input shape. 
" 59 | "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) 60 | ) 61 | 62 | is_tracing = torch._C._get_tracing_state() 63 | 64 | def _output_size(dim): 65 | _check_size_scale_factor(dim) 66 | if size is not None: 67 | if is_tracing: 68 | return [torch.tensor(i) for i in size] 69 | else: 70 | return size 71 | scale_factors = _ntuple(dim)(scale_factor) 72 | # math.floor might return float in py2.7 73 | 74 | # make scale_factor a tensor in tracing so constant doesn't get baked in 75 | if is_tracing: 76 | return [ 77 | (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float())) 78 | for i in range(dim) 79 | ] 80 | else: 81 | return [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)] 82 | 83 | def tf_calculate_resize_scale(in_size, out_size): 84 | if align_corners: 85 | if is_tracing: 86 | return (in_size - 1) / (out_size.float() - 1).clamp(min=1) 87 | else: 88 | return (in_size - 1) / max(1, out_size - 1) 89 | else: 90 | if is_tracing: 91 | return in_size / out_size.float() 92 | else: 93 | return in_size / out_size 94 | 95 | out_size = _output_size(2) 96 | scale_x = tf_calculate_resize_scale(input.shape[3], out_size[1]) 97 | scale_y = tf_calculate_resize_scale(input.shape[2], out_size[0]) 98 | 99 | def resample_using_grid_sample(): 100 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device) 101 | grid_x = grid_x * (2 * scale_x / (input.shape[3] - 1)) - 1 102 | 103 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device) 104 | grid_y = grid_y * (2 * scale_y / (input.shape[2] - 1)) - 1 105 | 106 | grid_x = grid_x.view(1, out_size[1]).repeat(out_size[0], 1) 107 | grid_y = grid_y.view(out_size[0], 1).repeat(1, out_size[1]) 108 | 109 | grid_xy = torch.cat((grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)), dim=2).unsqueeze(0) 110 | grid_xy = grid_xy.repeat(input.shape[0], 1, 1, 1) 111 | 112 | out = F.grid_sample(input, grid_xy, mode="bilinear", padding_mode="border", align_corners=True) 113 | return out 114 | 115 | def resample_manually(): 116 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device) 117 | grid_x = grid_x * torch.tensor(scale_x, dtype=torch.float32) 118 | grid_x_lo = grid_x.long() 119 | grid_x_hi = (grid_x_lo + 1).clamp_max(input.shape[3] - 1) 120 | grid_dx = grid_x - grid_x_lo.float() 121 | 122 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device) 123 | grid_y = grid_y * torch.tensor(scale_y, dtype=torch.float32) 124 | grid_y_lo = grid_y.long() 125 | grid_y_hi = (grid_y_lo + 1).clamp_max(input.shape[2] - 1) 126 | grid_dy = grid_y - grid_y_lo.float() 127 | 128 | # could be improved with index_select 129 | in_00 = input[:, :, grid_y_lo, :][:, :, :, grid_x_lo] 130 | in_01 = input[:, :, grid_y_lo, :][:, :, :, grid_x_hi] 131 | in_10 = input[:, :, grid_y_hi, :][:, :, :, grid_x_lo] 132 | in_11 = input[:, :, grid_y_hi, :][:, :, :, grid_x_hi] 133 | 134 | in_0 = in_00 + (in_01 - in_00) * grid_dx.view(1, 1, 1, out_size[1]) 135 | in_1 = in_10 + (in_11 - in_10) * grid_dx.view(1, 1, 1, out_size[1]) 136 | out = in_0 + (in_1 - in_0) * grid_dy.view(1, 1, out_size[0], 1) 137 | 138 | return out 139 | 140 | if method == "slow": 141 | out = resample_manually() 142 | else: 143 | out = resample_using_grid_sample() 144 | 145 | return out 146 | -------------------------------------------------------------------------------- /torch_fidelity/metric_fid.py: 
-------------------------------------------------------------------------------- 1 | # Functions fid_features_to_statistics and fid_statistics_to_metric are adapted from 2 | # https://github.com/bioinf-jku/TTUR/blob/master/fid.py commit id d4baae8 3 | # Distributed under Apache License 2.0: https://github.com/bioinf-jku/TTUR/blob/master/LICENSE 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch_fidelity.helpers import get_kwarg, vprint 9 | from torch_fidelity.utils import ( 10 | get_cacheable_input_name, 11 | cache_lookup_one_recompute_on_miss, 12 | extract_featuresdict_from_input_id_cached, 13 | create_feature_extractor, 14 | resolve_feature_extractor, 15 | resolve_feature_layer_for_metric, 16 | ) 17 | 18 | KEY_METRIC_FID = "frechet_inception_distance" 19 | 20 | 21 | def fid_features_to_statistics(features): 22 | assert torch.is_tensor(features) and features.dim() == 2 23 | features = features.numpy() 24 | mu = np.mean(features, axis=0) 25 | sigma = np.cov(features, rowvar=False) 26 | return { 27 | "mu": mu, 28 | "sigma": sigma, 29 | } 30 | 31 | 32 | def fid_statistics_to_metric(stat_1, stat_2, verbose): 33 | mu1, sigma1 = stat_1["mu"], stat_1["sigma"] 34 | mu2, sigma2 = stat_2["mu"], stat_2["sigma"] 35 | assert mu1.ndim == 1 and mu1.shape == mu2.shape and mu1.dtype == mu2.dtype 36 | assert sigma1.ndim == 2 and sigma1.shape == sigma2.shape and sigma1.dtype == sigma2.dtype 37 | 38 | diff = mu1 - mu2 39 | tr_covmean = np.sum(np.sqrt(np.linalg.eigvals(sigma1.dot(sigma2)).astype("complex128")).real) 40 | fid = float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) 41 | 42 | out = {KEY_METRIC_FID: fid} 43 | 44 | vprint(verbose, f"Frechet Inception Distance: {out[KEY_METRIC_FID]:.7g}") 45 | 46 | return out 47 | 48 | 49 | def fid_featuresdict_to_statistics(featuresdict, feat_layer_name): 50 | features = featuresdict[feat_layer_name] 51 | statistics = fid_features_to_statistics(features) 52 | return statistics 53 | 54 | 55 | def fid_featuresdict_to_statistics_cached( 56 | featuresdict, cacheable_input_name, feat_extractor, feat_layer_name, **kwargs 57 | ): 58 | def fn_recompute(): 59 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name) 60 | 61 | if cacheable_input_name is not None: 62 | feat_extractor_name = feat_extractor.get_name() 63 | cached_name = f"{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}" 64 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs) 65 | else: 66 | stat = fn_recompute() 67 | return stat 68 | 69 | 70 | def fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs): 71 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs) 72 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name) 73 | 74 | 75 | def fid_input_id_to_statistics_cached(input_id, feat_extractor, feat_layer_name, **kwargs): 76 | def fn_recompute(): 77 | return fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs) 78 | 79 | cacheable_input_name = get_cacheable_input_name(input_id, **kwargs) 80 | 81 | if cacheable_input_name is not None: 82 | feat_extractor_name = feat_extractor.get_name() 83 | cached_name = f"{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}" 84 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs) 85 | else: 86 | stat = fn_recompute() 87 | return stat 88 | 89 | 90 | def fid_inputs_to_metric(feat_extractor, **kwargs): 91 | feat_layer_name = 
resolve_feature_layer_for_metric("fid", **kwargs) 92 | verbose = get_kwarg("verbose", kwargs) 93 | 94 | vprint(verbose, f"Extracting statistics from input 1") 95 | stats_1 = fid_input_id_to_statistics_cached(1, feat_extractor, feat_layer_name, **kwargs) 96 | 97 | vprint(verbose, f"Extracting statistics from input 2") 98 | stats_2 = fid_input_id_to_statistics_cached(2, feat_extractor, feat_layer_name, **kwargs) 99 | 100 | metric = fid_statistics_to_metric(stats_1, stats_2, get_kwarg("verbose", kwargs)) 101 | return metric 102 | 103 | 104 | def calculate_fid(**kwargs): 105 | kwargs["fid"] = True 106 | feature_extractor = resolve_feature_extractor(**kwargs) 107 | feat_layer_name = resolve_feature_layer_for_metric("fid", **kwargs) 108 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 109 | metric = fid_inputs_to_metric(feat_extractor, **kwargs) 110 | return metric 111 | -------------------------------------------------------------------------------- /torch_fidelity/metric_isc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from torch_fidelity.helpers import get_kwarg, vprint 5 | from torch_fidelity.utils import ( 6 | extract_featuresdict_from_input_id_cached, 7 | create_feature_extractor, 8 | resolve_feature_extractor, 9 | resolve_feature_layer_for_metric, 10 | ) 11 | 12 | KEY_METRIC_ISC_MEAN = "inception_score_mean" 13 | KEY_METRIC_ISC_STD = "inception_score_std" 14 | 15 | 16 | def isc_features_to_metric(feature, splits=10, shuffle=True, rng_seed=2020): 17 | assert torch.is_tensor(feature) and feature.dim() == 2 18 | N, C = feature.shape 19 | if shuffle: 20 | rng = np.random.RandomState(rng_seed) 21 | feature = feature[rng.permutation(N), :] 22 | feature = feature.double() 23 | 24 | p = feature.softmax(dim=1) 25 | log_p = feature.log_softmax(dim=1) 26 | 27 | scores = [] 28 | for i in range(splits): 29 | p_chunk = p[(i * N // splits) : ((i + 1) * N // splits), :] 30 | log_p_chunk = log_p[(i * N // splits) : ((i + 1) * N // splits), :] 31 | q_chunk = p_chunk.mean(dim=0, keepdim=True) 32 | kl = p_chunk * (log_p_chunk - q_chunk.log()) 33 | kl = kl.sum(dim=1).mean().exp().item() 34 | scores.append(kl) 35 | 36 | return { 37 | KEY_METRIC_ISC_MEAN: float(np.mean(scores)), 38 | KEY_METRIC_ISC_STD: float(np.std(scores)), 39 | } 40 | 41 | 42 | def isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs): 43 | features = featuresdict[feat_layer_name] 44 | 45 | out = isc_features_to_metric( 46 | features, 47 | get_kwarg("isc_splits", kwargs), 48 | get_kwarg("samples_shuffle", kwargs), 49 | get_kwarg("rng_seed", kwargs), 50 | ) 51 | 52 | vprint( 53 | get_kwarg("verbose", kwargs), f"Inception Score: {out[KEY_METRIC_ISC_MEAN]:.7g} ± {out[KEY_METRIC_ISC_STD]:.7g}" 54 | ) 55 | 56 | return out 57 | 58 | 59 | def isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs): 60 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs) 61 | return isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs) 62 | 63 | 64 | def calculate_isc(input_id, **kwargs): 65 | kwargs["isc"] = True 66 | feature_extractor = resolve_feature_extractor(**kwargs) 67 | feat_layer_name = resolve_feature_layer_for_metric("isc", **kwargs) 68 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 69 | metric = isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs) 70 | return metric 71 
| -------------------------------------------------------------------------------- /torch_fidelity/metric_kid.py: -------------------------------------------------------------------------------- 1 | # Functions mmd2 and polynomial_kernel are adapted from 2 | # https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py 3 | # Distributed under BSD 3-Clause: https://github.com/mbinkowski/MMD-GAN/blob/master/LICENSE 4 | 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from torch_fidelity.helpers import get_kwarg, vassert, vprint 10 | from torch_fidelity.utils import ( 11 | create_feature_extractor, 12 | extract_featuresdict_from_input_id_cached, 13 | resolve_feature_extractor, 14 | resolve_feature_layer_for_metric, 15 | ) 16 | 17 | KEY_METRIC_KID_MEAN = "kernel_inception_distance_mean" 18 | KEY_METRIC_KID_STD = "kernel_inception_distance_std" 19 | 20 | 21 | def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est="unbiased"): 22 | vassert(mmd_est in ("biased", "unbiased", "u-statistic"), "Invalid value of mmd_est") 23 | 24 | m = K_XX.shape[0] 25 | assert K_XX.shape == (m, m) 26 | assert K_XY.shape == (m, m) 27 | assert K_YY.shape == (m, m) 28 | 29 | # Get the various sums of kernels that we'll use 30 | # Kts drop the diagonal, but we don't need to compute them explicitly 31 | if unit_diagonal: 32 | diag_X = diag_Y = 1 33 | sum_diag_X = sum_diag_Y = m 34 | else: 35 | diag_X = np.diagonal(K_XX) 36 | diag_Y = np.diagonal(K_YY) 37 | 38 | sum_diag_X = diag_X.sum() 39 | sum_diag_Y = diag_Y.sum() 40 | 41 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X 42 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y 43 | K_XY_sums_0 = K_XY.sum(axis=0) 44 | 45 | Kt_XX_sum = Kt_XX_sums.sum() 46 | Kt_YY_sum = Kt_YY_sums.sum() 47 | K_XY_sum = K_XY_sums_0.sum() 48 | 49 | if mmd_est == "biased": 50 | mmd2 = (Kt_XX_sum + sum_diag_X) / (m * m) \ 51 | + (Kt_YY_sum + sum_diag_Y) / (m * m) \ 52 | - 2 * K_XY_sum / (m * m) # fmt: skip 53 | else: 54 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1)) 55 | if mmd_est == "unbiased": 56 | mmd2 -= 2 * K_XY_sum / (m * m) 57 | else: 58 | mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1)) 59 | 60 | return mmd2 61 | 62 | 63 | def kernel_poly(X, Y, **kwargs): 64 | degree = get_kwarg("kid_kernel_poly_degree", kwargs) 65 | gamma = get_kwarg("kid_kernel_poly_gamma", kwargs) 66 | coef0 = get_kwarg("kid_kernel_poly_coef0", kwargs) 67 | if gamma is None: 68 | gamma = 1.0 / X.shape[1] 69 | K = (np.matmul(X, Y.T) * gamma + coef0) ** degree 70 | return K 71 | 72 | 73 | def kernel_rbf(X, Y, **kwargs): 74 | sigma = get_kwarg("kid_kernel_rbf_sigma", kwargs) 75 | vassert(sigma is not None and sigma > 0, "kid_kernel_rbf_sigma must be positive") 76 | XX = np.sum(X**2, axis=1) 77 | YY = np.sum(Y**2, axis=1) 78 | XY = np.dot(X, Y.T) 79 | K = np.exp((2 * XY - np.outer(XX, np.ones(YY.shape[0])) - np.outer(np.ones(XX.shape[0]), YY)) / (2 * sigma**2)) 80 | return K 81 | 82 | 83 | def kernel_mmd(features_1, features_2, **kwargs): 84 | kernel = get_kwarg("kid_kernel", kwargs) 85 | vassert(kernel in ("poly", "rbf"), "Invalid KID kernel") 86 | kernel = { 87 | "poly": kernel_poly, 88 | "rbf": kernel_rbf, 89 | }[kernel] 90 | k_11 = kernel(features_1, features_1, **kwargs) 91 | k_22 = kernel(features_2, features_2, **kwargs) 92 | k_12 = kernel(features_1, features_2, **kwargs) 93 | return mmd2(k_11, k_12, k_22) 94 | 95 | 96 | def kid_features_to_metric(features_1, features_2, **kwargs): 97 | assert torch.is_tensor(features_1) and features_1.dim() == 2 98 | assert torch.is_tensor(features_2) 
and features_2.dim() == 2 99 | assert features_1.shape[1] == features_2.shape[1] 100 | 101 | kid_subsets = get_kwarg("kid_subsets", kwargs) 102 | kid_subset_size = get_kwarg("kid_subset_size", kwargs) 103 | verbose = get_kwarg("verbose", kwargs) 104 | 105 | n_samples_1, n_samples_2 = len(features_1), len(features_2) 106 | vassert( 107 | n_samples_1 >= kid_subset_size and n_samples_2 >= kid_subset_size, 108 | f"KID subset size {kid_subset_size} cannot be smaller than the number of samples (input_1: {n_samples_1}, " 109 | f'input_2: {n_samples_2}). Consider using "kid_subset_size" kwarg or "--kid-subset-size" command line key to ' 110 | f"proceed.", 111 | ) 112 | 113 | features_1 = features_1.cpu().numpy() 114 | features_2 = features_2.cpu().numpy() 115 | 116 | mmds = np.zeros(kid_subsets) 117 | rng = np.random.RandomState(get_kwarg("rng_seed", kwargs)) 118 | 119 | for i in tqdm( 120 | range(kid_subsets), disable=not verbose, leave=False, unit="subsets", desc="Kernel Inception Distance" 121 | ): 122 | f1 = features_1[rng.choice(n_samples_1, kid_subset_size, replace=False)] 123 | f2 = features_2[rng.choice(n_samples_2, kid_subset_size, replace=False)] 124 | o = kernel_mmd(f1, f2, **kwargs) 125 | mmds[i] = o 126 | 127 | out = { 128 | KEY_METRIC_KID_MEAN: float(np.mean(mmds)), 129 | KEY_METRIC_KID_STD: float(np.std(mmds)), 130 | } 131 | 132 | vprint(verbose, f"Kernel Inception Distance: {out[KEY_METRIC_KID_MEAN]:.7g} ± {out[KEY_METRIC_KID_STD]:.7g}") 133 | 134 | return out 135 | 136 | 137 | def kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs): 138 | features_1 = featuresdict_1[feat_layer_name] 139 | features_2 = featuresdict_2[feat_layer_name] 140 | metric = kid_features_to_metric(features_1, features_2, **kwargs) 141 | if metric[KEY_METRIC_KID_MEAN] < 0 and get_kwarg("verbose", kwargs): 142 | print("KID values slightly less than 0 are valid and indicate that distributions are very similar") 143 | return metric 144 | 145 | 146 | def calculate_kid(**kwargs): 147 | kwargs["kid"] = True 148 | feature_extractor = resolve_feature_extractor(**kwargs) 149 | feat_layer_name = resolve_feature_layer_for_metric("kid", **kwargs) 150 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 151 | featuresdict_1 = extract_featuresdict_from_input_id_cached(1, feat_extractor, **kwargs) 152 | featuresdict_2 = extract_featuresdict_from_input_id_cached(2, feat_extractor, **kwargs) 153 | metric = kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs) 154 | return metric 155 | -------------------------------------------------------------------------------- /torch_fidelity/metric_ppl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | 5 | from torch_fidelity.generative_model_base import GenerativeModelBase 6 | from torch_fidelity.helpers import get_kwarg, vassert, vprint 7 | from torch_fidelity.utils import ( 8 | sample_random, 9 | batch_interp, 10 | create_sample_similarity, 11 | prepare_input_descriptor_from_input_id, 12 | prepare_input_from_descriptor, 13 | ) 14 | 15 | KEY_METRIC_PPL_RAW = "perceptual_path_length_raw" 16 | KEY_METRIC_PPL_MEAN = "perceptual_path_length_mean" 17 | KEY_METRIC_PPL_STD = "perceptual_path_length_std" 18 | 19 | 20 | def calculate_ppl(input_id, **kwargs): 21 | """ 22 | Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py 23 | """ 24 | 
kwargs["ppl"] = True 25 | batch_size = get_kwarg("batch_size", kwargs) 26 | cuda = get_kwarg("cuda", kwargs) 27 | verbose = get_kwarg("verbose", kwargs) 28 | epsilon = get_kwarg("ppl_epsilon", kwargs) 29 | interp = get_kwarg("ppl_z_interp_mode", kwargs) 30 | reduction = get_kwarg("ppl_reduction", kwargs) 31 | similarity_name = get_kwarg("ppl_sample_similarity", kwargs) 32 | sample_similarity_resize = get_kwarg("ppl_sample_similarity_resize", kwargs) 33 | sample_similarity_dtype = get_kwarg("ppl_sample_similarity_dtype", kwargs) 34 | discard_percentile_lower = get_kwarg("ppl_discard_percentile_lower", kwargs) 35 | discard_percentile_higher = get_kwarg("ppl_discard_percentile_higher", kwargs) 36 | 37 | input_desc = prepare_input_descriptor_from_input_id(input_id, **kwargs) 38 | model = prepare_input_from_descriptor(input_desc, **kwargs) 39 | vassert( 40 | isinstance(model, GenerativeModelBase), 41 | "Input needs to be an instance of GenerativeModelBase, which can be either passed programmatically by wrapping " 42 | "a model with GenerativeModelModuleWrapper, or via command line by specifying a path to a ONNX or PTH (JIT) " 43 | "model and a set of input1_model_* arguments", 44 | ) 45 | 46 | if cuda: 47 | model.cuda() 48 | 49 | input_model_num_samples = input_desc["input_model_num_samples"] 50 | input_model_num_classes = model.num_classes 51 | input_model_z_size = model.z_size 52 | input_model_z_type = model.z_type 53 | 54 | vassert(input_model_num_classes >= 0, "Model can be unconditional (0 classes) or conditional (positive)") 55 | vassert( 56 | type(input_model_z_size) is int and input_model_z_size > 0, 57 | 'Dimensionality of generator noise not specified ("input1_model_z_size" argument)', 58 | ) 59 | vassert(type(epsilon) is float and epsilon > 0, "Epsilon must be a small positive floating point number") 60 | vassert(type(input_model_num_samples) is int and input_model_num_samples > 0, "Number of samples must be positive") 61 | vassert(reduction in ("none", "mean"), "Reduction must be one of [none, mean]") 62 | vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, "Invalid percentile") 63 | vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, "Invalid percentile") 64 | if discard_percentile_lower is not None and discard_percentile_higher is not None: 65 | vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, "Invalid percentiles") 66 | 67 | sample_similarity = create_sample_similarity( 68 | similarity_name, 69 | sample_similarity_resize=sample_similarity_resize, 70 | sample_similarity_dtype=sample_similarity_dtype, 71 | **kwargs, 72 | ) 73 | 74 | is_cond = input_desc["input_model_num_classes"] > 0 75 | 76 | rng = np.random.RandomState(get_kwarg("rng_seed", kwargs)) 77 | 78 | lat_e0 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type) 79 | lat_e1 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type) 80 | lat_e1 = batch_interp(lat_e0, lat_e1, epsilon, interp) 81 | 82 | labels = None 83 | if is_cond: 84 | labels = torch.from_numpy(rng.randint(0, input_model_num_classes, (input_model_num_samples,))) 85 | 86 | distances = [] 87 | 88 | with tqdm( 89 | disable=not verbose, leave=False, unit="samples", total=input_model_num_samples, desc="Perceptual Path Length" 90 | ) as t, torch.no_grad(): 91 | for begin_id in range(0, input_model_num_samples, batch_size): 92 | end_id = min(begin_id + batch_size, input_model_num_samples) 93 | batch_sz = end_id - begin_id 
94 | 95 | batch_lat_e0 = lat_e0[begin_id:end_id] 96 | batch_lat_e1 = lat_e1[begin_id:end_id] 97 | if is_cond: 98 | batch_labels = labels[begin_id:end_id] 99 | 100 | if cuda: 101 | batch_lat_e0 = batch_lat_e0.cuda(non_blocking=True) 102 | batch_lat_e1 = batch_lat_e1.cuda(non_blocking=True) 103 | if is_cond: 104 | batch_labels = batch_labels.cuda(non_blocking=True) 105 | 106 | if is_cond: 107 | rgb_e01 = model.forward( 108 | torch.cat((batch_lat_e0, batch_lat_e1), dim=0), 109 | torch.cat((batch_labels, batch_labels), dim=0), 110 | ) 111 | else: 112 | rgb_e01 = model.forward(torch.cat((batch_lat_e0, batch_lat_e1), dim=0)) 113 | rgb_e0, rgb_e1 = rgb_e01.chunk(2) 114 | 115 | sim = sample_similarity(rgb_e0, rgb_e1) 116 | dist_lat_e01 = sim / (epsilon**2) 117 | distances.append(dist_lat_e01.cpu().numpy()) 118 | 119 | t.update(batch_sz) 120 | 121 | distances = np.concatenate(distances, axis=0) 122 | 123 | cond, lo, hi = None, None, None 124 | if discard_percentile_lower is not None: 125 | lo = np.percentile(distances, discard_percentile_lower, interpolation="lower") 126 | cond = lo <= distances 127 | if discard_percentile_higher is not None: 128 | hi = np.percentile(distances, discard_percentile_higher, interpolation="higher") 129 | cond = np.logical_and(cond, distances <= hi) 130 | if cond is not None: 131 | distances = np.extract(cond, distances) 132 | 133 | out = { 134 | KEY_METRIC_PPL_MEAN: float(np.mean(distances)), 135 | KEY_METRIC_PPL_STD: float(np.std(distances)), 136 | } 137 | if reduction == "none": 138 | out[KEY_METRIC_PPL_RAW] = distances 139 | 140 | vprint(verbose, f"Perceptual Path Length: {out[KEY_METRIC_PPL_MEAN]:.7g} ± {out[KEY_METRIC_PPL_STD]:.7g}") 141 | 142 | return out 143 | -------------------------------------------------------------------------------- /torch_fidelity/metric_prc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_fidelity.helpers import get_kwarg, vprint 4 | from torch_fidelity.utils import ( 5 | create_feature_extractor, 6 | extract_featuresdict_from_input_id_cached, 7 | resolve_feature_extractor, 8 | resolve_feature_layer_for_metric, 9 | ) 10 | 11 | KEY_METRIC_PRECISION = "precision" 12 | KEY_METRIC_RECALL = "recall" 13 | KEY_METRIC_F_SCORE = "f_score" 14 | 15 | 16 | def calc_cdist_part(features_1, features_2, batch_size=10000): 17 | dists = [] 18 | for feat2_batch in features_2.split(batch_size): 19 | dists.append(torch.cdist(features_1, feat2_batch).cpu()) 20 | return torch.cat(dists, dim=1) 21 | 22 | 23 | def calculate_precision_recall_part(features_1, features_2, neighborhood=3, batch_size=10000): 24 | # Precision 25 | dist_nn_1 = [] 26 | for feat_1_batch in features_1.split(batch_size): 27 | dist_nn_1.append(calc_cdist_part(feat_1_batch, features_1, batch_size).kthvalue(neighborhood + 1).values) 28 | dist_nn_1 = torch.cat(dist_nn_1) 29 | precision = [] 30 | for feat_2_batch in features_2.split(batch_size): 31 | dist_2_1_batch = calc_cdist_part(feat_2_batch, features_1, batch_size) 32 | precision.append((dist_2_1_batch <= dist_nn_1).any(dim=1).float()) 33 | precision = torch.cat(precision).mean().item() 34 | # Recall 35 | dist_nn_2 = [] 36 | for feat_2_batch in features_2.split(batch_size): 37 | dist_nn_2.append(calc_cdist_part(feat_2_batch, features_2, batch_size).kthvalue(neighborhood + 1).values) 38 | dist_nn_2 = torch.cat(dist_nn_2) 39 | recall = [] 40 | for feat_1_batch in features_1.split(batch_size): 41 | dist_1_2_batch = calc_cdist_part(feat_1_batch, features_2, batch_size) 42 
| recall.append((dist_1_2_batch <= dist_nn_2).any(dim=1).float()) 43 | recall = torch.cat(recall).mean().item() 44 | return precision, recall 45 | 46 | 47 | def calc_cdist_full(features_1, features_2, batch_size=10000): 48 | dists = [] 49 | for feat1_batch in features_1.split(batch_size): 50 | dists_batch = [] 51 | for feat2_batch in features_2.split(batch_size): 52 | dists_batch.append(torch.cdist(feat1_batch, feat2_batch).cpu()) 53 | dists.append(torch.cat(dists_batch, dim=1)) 54 | return torch.cat(dists, dim=0) 55 | 56 | 57 | def calculate_precision_recall_full(features_1, features_2, neighborhood=3, batch_size=10000): 58 | dist_nn_1 = calc_cdist_full(features_1, features_1, batch_size).kthvalue(neighborhood + 1).values 59 | dist_nn_2 = calc_cdist_full(features_2, features_2, batch_size).kthvalue(neighborhood + 1).values 60 | dist_2_1 = calc_cdist_full(features_2, features_1, batch_size) 61 | dist_1_2 = dist_2_1.T 62 | # Precision 63 | precision = (dist_2_1 <= dist_nn_1).any(dim=1).float().mean().item() 64 | # Recall 65 | recall = (dist_1_2 <= dist_nn_2).any(dim=1).float().mean().item() 66 | return precision, recall 67 | 68 | 69 | def prc_features_to_metric(features_1, features_2, **kwargs): 70 | # Convention: features_1 is REAL, features_2 is GENERATED. This important for the notion of precision/recall only. 71 | assert torch.is_tensor(features_1) and features_1.dim() == 2 72 | assert torch.is_tensor(features_2) and features_2.dim() == 2 73 | assert features_1.shape[1] == features_2.shape[1] 74 | 75 | neighborhood = get_kwarg("prc_neighborhood", kwargs) 76 | batch_size = get_kwarg("prc_batch_size", kwargs) 77 | save_cpu_ram = get_kwarg("save_cpu_ram", kwargs) 78 | verbose = get_kwarg("verbose", kwargs) 79 | 80 | calculate_precision_recall_fn = calculate_precision_recall_part if save_cpu_ram else calculate_precision_recall_full 81 | precision, recall = calculate_precision_recall_fn(features_1, features_2, neighborhood, batch_size) 82 | f_score = 2 * precision * recall / max(1e-5, precision + recall) 83 | 84 | out = { 85 | KEY_METRIC_PRECISION: precision, 86 | KEY_METRIC_RECALL: recall, 87 | KEY_METRIC_F_SCORE: f_score, 88 | } 89 | 90 | vprint(verbose, f"Precision: {out[KEY_METRIC_PRECISION]:.7g}") 91 | vprint(verbose, f"Recall: {out[KEY_METRIC_RECALL]:.7g}") 92 | vprint(verbose, f"F-score: {out[KEY_METRIC_F_SCORE]:.7g}") 93 | 94 | return out 95 | 96 | 97 | def prc_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs): 98 | features_1 = featuresdict_1[feat_layer_name] 99 | features_2 = featuresdict_2[feat_layer_name] 100 | metric = prc_features_to_metric(features_1, features_2, **kwargs) 101 | return metric 102 | 103 | 104 | def calculate_prc(**kwargs): 105 | kwargs["prc"] = True 106 | feature_extractor = resolve_feature_extractor(**kwargs) 107 | feat_layer_name = resolve_feature_layer_for_metric("prc", **kwargs) 108 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs) 109 | featuresdict_1 = extract_featuresdict_from_input_id_cached(1, feat_extractor, **kwargs) 110 | featuresdict_2 = extract_featuresdict_from_input_id_cached(2, feat_extractor, **kwargs) 111 | metric = prc_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs) 112 | return metric 113 | -------------------------------------------------------------------------------- /torch_fidelity/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def batch_normalize_last_dim(v, 
eps=1e-7): 5 | return v / (v**2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps) 6 | 7 | 8 | def random_normal(rng, shape): 9 | return torch.from_numpy(rng.randn(*shape)).float() 10 | 11 | 12 | def random_unit(rng, shape): 13 | return batch_normalize_last_dim(torch.from_numpy(rng.rand(*shape)).float()) 14 | 15 | 16 | def random_uniform_0_1(rng, shape): 17 | return torch.from_numpy(rng.rand(*shape)).float() 18 | 19 | 20 | def batch_lerp(a, b, t): 21 | return a + (b - a) * t 22 | 23 | 24 | def batch_slerp_any(a, b, t, eps=1e-7): 25 | assert torch.is_tensor(a) and torch.is_tensor(b) and a.dim() >= 2 and a.shape == b.shape 26 | ndims, N = a.dim() - 1, a.shape[-1] 27 | a_1 = batch_normalize_last_dim(a, eps) 28 | b_1 = batch_normalize_last_dim(b, eps) 29 | d = (a_1 * b_1).sum(dim=-1, keepdim=True) 30 | mask_zero = (a_1.norm(dim=-1, keepdim=True) < eps) | (b_1.norm(dim=-1, keepdim=True) < eps) 31 | mask_collinear = (d > 1 - eps) | (d < -1 + eps) 32 | mask_lerp = (mask_zero | mask_collinear).repeat([1 for _ in range(ndims)] + [N]) 33 | omega = d.acos() 34 | denom = omega.sin().clamp_min(eps) 35 | coef_a = ((1 - t) * omega).sin() / denom 36 | coef_b = (t * omega).sin() / denom 37 | out = coef_a * a + coef_b * b 38 | out[mask_lerp] = batch_lerp(a, b, t)[mask_lerp] 39 | return out 40 | 41 | 42 | def batch_slerp_unit(a, b, t, eps=1e-7): 43 | out = batch_slerp_any(a, b, t, eps) 44 | out = batch_normalize_last_dim(out, eps) 45 | return out 46 | -------------------------------------------------------------------------------- /torch_fidelity/registry.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch_fidelity.datasets import TransformPILtoRGBTensor, Cifar10_RGB, Cifar100_RGB, STL10_RGB 4 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase 5 | from torch_fidelity.feature_extractor_clip import FeatureExtractorCLIP 6 | from torch_fidelity.feature_extractor_dinov2 import FeatureExtractorDinoV2 7 | from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 8 | from torch_fidelity.feature_extractor_vgg16 import FeatureExtractorVGG16 9 | from torch_fidelity.helpers import vassert 10 | from torch_fidelity.noise import ( 11 | random_normal, 12 | random_unit, 13 | random_uniform_0_1, 14 | batch_lerp, 15 | batch_slerp_any, 16 | batch_slerp_unit, 17 | ) 18 | from torch_fidelity.sample_similarity_base import SampleSimilarityBase 19 | from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS 20 | 21 | DATASETS_REGISTRY = dict() 22 | FEATURE_EXTRACTORS_REGISTRY = dict() 23 | SAMPLE_SIMILARITY_REGISTRY = dict() 24 | NOISE_SOURCE_REGISTRY = dict() 25 | INTERPOLATION_REGISTRY = dict() 26 | 27 | 28 | def register_dataset(name, fn_create): 29 | """ 30 | Registers a new input source. 31 | 32 | Args: 33 | 34 | name (str): Unique name of the input source. 35 | 36 | fn_create (callable): A constructor of a :class:`~torch:torch.utils.data.Dataset` instance. Callable arguments: 37 | 38 | - `root` (str): Location where the dataset files may be downloaded. 39 | - `download` (bool): Whether to perform downloading or rely on the cached version. 
40 | """ 41 | vassert(type(name) is str, "Dataset must be given a name") 42 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces") 43 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)") 44 | vassert(name not in DATASETS_REGISTRY, f'Dataset "{name}" is already registered') 45 | vassert( 46 | callable(fn_create), 47 | "Dataset must be provided as a callable (function, lambda) with 2 bool arguments: root, download", 48 | ) 49 | DATASETS_REGISTRY[name] = fn_create 50 | 51 | 52 | def register_feature_extractor(name, cls): 53 | """ 54 | Registers a new feature extractor. 55 | 56 | Args: 57 | 58 | name (str): Unique name of the feature extractor. 59 | 60 | cls (FeatureExtractorBase): Instance of :class:`FeatureExtractorBase`, implementing a new feature extractor. 61 | """ 62 | vassert(type(name) is str, "Feature extractor must be given a name") 63 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces") 64 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)") 65 | vassert(name not in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" is already registered') 66 | vassert( 67 | issubclass(cls, FeatureExtractorBase), "Feature extractor class must be subclassed from FeatureExtractorBase" 68 | ) 69 | FEATURE_EXTRACTORS_REGISTRY[name] = cls 70 | 71 | 72 | def register_sample_similarity(name, cls): 73 | """ 74 | Registers a new sample similarity measure. 75 | 76 | Args: 77 | 78 | name (str): Unique name of the sample similarity measure. 79 | 80 | cls (SampleSimilarityBase): Instance of :class:`SampleSimilarityBase`, implementing a new sample similarity 81 | measure. 82 | """ 83 | vassert(type(name) is str, "Sample similarity must be given a name") 84 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces") 85 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)") 86 | vassert(name not in SAMPLE_SIMILARITY_REGISTRY, f'Sample similarity "{name}" is already registered') 87 | vassert( 88 | issubclass(cls, SampleSimilarityBase), "Sample similarity class must be subclassed from SampleSimilarityBase" 89 | ) 90 | SAMPLE_SIMILARITY_REGISTRY[name] = cls 91 | 92 | 93 | def register_noise_source(name, fn_generate): 94 | """ 95 | Registers a new noise source, which can generate samples to be used as inputs to generative models. 96 | 97 | Args: 98 | 99 | name (str): Unique name of the noise source. 100 | 101 | fn_generate (callable): Generator of a random samples of specified type and shape. Callable arguments: 102 | 103 | - `rng` (numpy.random.RandomState): random number generator state, initialized with \ 104 | :paramref:`~calculate_metrics.seed`. 105 | - `shape` (torch.Size): shape of the tensor of random samples. 106 | """ 107 | vassert(type(name) is str, "Noise source must be given a name") 108 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces") 109 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)") 110 | vassert(name not in NOISE_SOURCE_REGISTRY, f'Noise source "{name}" is already registered') 111 | vassert( 112 | callable(fn_generate), 113 | "Noise source must be provided as a callable (function, lambda) with 2 arguments: rng, shape", 114 | ) 115 | NOISE_SOURCE_REGISTRY[name] = fn_generate 116 | 117 | 118 | def register_interpolation(name, fn_interpolate): 119 | """ 120 | Registers a new sample interpolation method. 
121 | 122 | Args: 123 | 124 | name (str): Unique name of the interpolation method. 125 | 126 | fn_interpolate (callable): Sample interpolation function. Callable arguments: 127 | 128 | - `a` (torch.Tensor): batch of the first endpoint samples. 129 | - `b` (torch.Tensor): batch of the second endpoint samples. 130 | - `t` (float): interpolation coefficient in the range [0,1]. 131 | """ 132 | vassert(type(name) is str, "Interpolation must be given a name") 133 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces") 134 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)") 135 | vassert(name not in INTERPOLATION_REGISTRY, f'Interpolation "{name}" is already registered') 136 | vassert( 137 | callable(fn_interpolate), 138 | "Interpolation must be provided as a callable (function, lambda) with 3 arguments: a, b, t", 139 | ) 140 | INTERPOLATION_REGISTRY[name] = fn_interpolate 141 | 142 | 143 | register_dataset( 144 | "cifar10-train", 145 | lambda root, download: Cifar10_RGB(root, train=True, transform=TransformPILtoRGBTensor(), download=download), 146 | ) 147 | register_dataset( 148 | "cifar10-val", 149 | lambda root, download: Cifar10_RGB(root, train=False, transform=TransformPILtoRGBTensor(), download=download), 150 | ) 151 | register_dataset( 152 | "cifar100-train", 153 | lambda root, download: Cifar100_RGB(root, train=True, transform=TransformPILtoRGBTensor(), download=download), 154 | ) 155 | register_dataset( 156 | "cifar100-val", 157 | lambda root, download: Cifar100_RGB(root, train=False, transform=TransformPILtoRGBTensor(), download=download), 158 | ) 159 | register_dataset( 160 | "stl10-train", 161 | lambda root, download: STL10_RGB(root, split="train", transform=TransformPILtoRGBTensor(), download=download), 162 | ) 163 | register_dataset( 164 | "stl10-test", 165 | lambda root, download: STL10_RGB(root, split="test", transform=TransformPILtoRGBTensor(), download=download), 166 | ) 167 | register_dataset( 168 | "stl10-unlabeled", 169 | lambda root, download: STL10_RGB(root, split="unlabeled", transform=TransformPILtoRGBTensor(), download=download), 170 | ) 171 | 172 | register_feature_extractor("inception-v3-compat", FeatureExtractorInceptionV3) 173 | 174 | register_feature_extractor("vgg16", FeatureExtractorVGG16) 175 | 176 | register_feature_extractor("clip-rn50", FeatureExtractorCLIP) 177 | register_feature_extractor("clip-rn101", FeatureExtractorCLIP) 178 | register_feature_extractor("clip-rn50x4", FeatureExtractorCLIP) 179 | register_feature_extractor("clip-rn50x16", FeatureExtractorCLIP) 180 | register_feature_extractor("clip-rn50x64", FeatureExtractorCLIP) 181 | register_feature_extractor("clip-vit-b-32", FeatureExtractorCLIP) 182 | register_feature_extractor("clip-vit-b-16", FeatureExtractorCLIP) 183 | register_feature_extractor("clip-vit-l-14", FeatureExtractorCLIP) 184 | register_feature_extractor("clip-vit-l-14-336px", FeatureExtractorCLIP) 185 | 186 | register_feature_extractor("dinov2-vit-s-14", FeatureExtractorDinoV2) 187 | register_feature_extractor("dinov2-vit-b-14", FeatureExtractorDinoV2) 188 | register_feature_extractor("dinov2-vit-l-14", FeatureExtractorDinoV2) 189 | register_feature_extractor("dinov2-vit-g-14", FeatureExtractorDinoV2) 190 | 191 | register_sample_similarity("lpips-vgg16", SampleSimilarityLPIPS) 192 | 193 | register_noise_source("normal", random_normal) 194 | register_noise_source("unit", random_unit) 195 | register_noise_source("uniform_0_1", random_uniform_0_1) 196 | 197 | 
register_interpolation("lerp", batch_lerp) 198 | register_interpolation("slerp_any", batch_slerp_any) 199 | register_interpolation("slerp_unit", batch_slerp_unit) 200 | -------------------------------------------------------------------------------- /torch_fidelity/sample_similarity_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from torch_fidelity.helpers import vassert 4 | 5 | 6 | class SampleSimilarityBase(nn.Module): 7 | def __init__(self, name): 8 | """ 9 | Base class for samples similarity measures that can be used in :func:`calculate_metrics`. 10 | 11 | Args: 12 | 13 | name (str): Unique name of the subclassed sample similarity measure, must be the same as used in 14 | :func:`register_sample_similarity`. 15 | """ 16 | super(SampleSimilarityBase, self).__init__() 17 | vassert(type(name) is str, "Sample similarity name must be a string") 18 | self.name = name 19 | 20 | def get_name(self): 21 | return self.name 22 | 23 | def forward(self, *args): 24 | """ 25 | Returns the value of sample similarity between the inputs. 26 | """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /torch_fidelity/sample_similarity_lpips.py: -------------------------------------------------------------------------------- 1 | # Adaptation of the following sources: 2 | # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py 3 | # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py 4 | # Distributed under BSD 2-Clause: https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE 5 | import sys 6 | from contextlib import redirect_stdout 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.hub import load_state_dict_from_url 11 | 12 | from torch_fidelity.helpers import vassert, text_to_dtype 13 | from torch_fidelity.sample_similarity_base import SampleSimilarityBase 14 | from torch_fidelity.utils_torchvision import torchvision_load_pretrained_vgg16 15 | 16 | # VGG16 LPIPS original weights re-uploaded from the following location: 17 | # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth 18 | # Distributed under BSD 2-Clause: https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE 19 | URL_VGG16_LPIPS = "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-vgg16-lpips.pth" 20 | 21 | 22 | class VGG16features(torch.nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | vgg_pretrained_features = torchvision_load_pretrained_vgg16().features 26 | self.slice1 = torch.nn.Sequential() 27 | self.slice2 = torch.nn.Sequential() 28 | self.slice3 = torch.nn.Sequential() 29 | self.slice4 = torch.nn.Sequential() 30 | self.slice5 = torch.nn.Sequential() 31 | self.N_slices = 5 32 | for x in range(4): 33 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 34 | for x in range(4, 9): 35 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(9, 16): 37 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(16, 23): 39 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 40 | for x in range(23, 30): 41 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 42 | self.eval() 43 | for param in self.parameters(): 44 | param.requires_grad = False 45 | 46 | def forward(self, X): 47 | h = self.slice1(X) 48 | h_relu1_2 = h 49 | h = self.slice2(h) 50 | h_relu2_2 = h 51 
| h = self.slice3(h) 52 | h_relu3_3 = h 53 | h = self.slice4(h) 54 | h_relu4_3 = h 55 | h = self.slice5(h) 56 | h_relu5_3 = h 57 | return h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3 58 | 59 | 60 | def spatial_average(in_tensor): 61 | return in_tensor.mean([2, 3]).squeeze(1) 62 | 63 | 64 | def normalize_tensor(in_features, eps=1e-10): 65 | norm_factor = torch.sqrt(torch.sum(in_features**2, dim=1, keepdim=True)) 66 | return in_features / (norm_factor + eps) 67 | 68 | 69 | class NetLinLayer(nn.Module): 70 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 71 | super(NetLinLayer, self).__init__() 72 | layers = ( 73 | [ 74 | nn.Dropout(), 75 | ] 76 | if use_dropout 77 | else [] 78 | ) 79 | layers += [ 80 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 81 | ] 82 | self.model = nn.Sequential(*layers) 83 | 84 | 85 | class SampleSimilarityLPIPS(SampleSimilarityBase): 86 | def __init__(self, name, sample_similarity_resize=None, sample_similarity_dtype=None, **kwargs): 87 | """ 88 | LPIPS sample similarity measure for 2D RGB 24bit images. 89 | 90 | Args: 91 | 92 | name (str): Unique name of the sample similarity measure, must be the same as used in 93 | :func:`register_sample_similarity`. 94 | 95 | sample_similarity_resize (int or None): Resizes inputs to this size if set, keeps as is if `None`. 96 | 97 | sample_similarity_dtype (str): Coerces tensor dtype to one of the following: 'uint8', 'float32'. 98 | This is useful when the inputs are generated by a generative model, to ensure the proper data range and 99 | quantization. 100 | """ 101 | super(SampleSimilarityLPIPS, self).__init__(name) 102 | self.sample_similarity_resize = sample_similarity_resize 103 | self.sample_similarity_dtype = sample_similarity_dtype 104 | self.chns = [64, 128, 256, 512, 512] 105 | self.L = len(self.chns) 106 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=True) 107 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=True) 108 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=True) 109 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=True) 110 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=True) 111 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 112 | with redirect_stdout(sys.stderr): 113 | state_dict = load_state_dict_from_url(URL_VGG16_LPIPS, map_location="cpu", progress=True) 114 | self.load_state_dict(state_dict) 115 | self.net = VGG16features() 116 | self.eval() 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | @staticmethod 121 | def normalize(x): 122 | # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] 123 | mean_rescaled = (1 + torch.tensor([-0.030, -0.088, -0.188], device=x.device)[None, :, None, None]) * 255 / 2 124 | inv_std_rescaled = 2 / (torch.tensor([0.458, 0.448, 0.450], device=x.device)[None, :, None, None] * 255) 125 | x = (x.float() - mean_rescaled) * inv_std_rescaled 126 | return x 127 | 128 | @staticmethod 129 | def resize(x, size): 130 | if x.shape[-1] > size and x.shape[-2] > size: 131 | x = torch.nn.functional.interpolate(x, (size, size), mode="area") 132 | else: 133 | x = torch.nn.functional.interpolate(x, (size, size), mode="bilinear", align_corners=False) 134 | return x 135 | 136 | def forward(self, in0, in1): 137 | vassert(torch.is_tensor(in0) and torch.is_tensor(in1), "Inputs must be torch tensors") 138 | vassert(in0.dim() == 4 and in0.shape[1] == 3, "Input 0 is not Bx3xHxW") 139 | vassert(in1.dim() == 4 and in1.shape[1] == 3, "Input 1 is not 
Bx3xHxW") 140 | if self.sample_similarity_dtype is not None: 141 | dtype = text_to_dtype(self.sample_similarity_dtype, None) 142 | vassert( 143 | dtype is not None and in0.dtype == dtype and in1.dtype == dtype, f"Unexpected input dtype ({in0.dtype})" 144 | ) 145 | in0_input = self.normalize(in0) 146 | in1_input = self.normalize(in1) 147 | 148 | if self.sample_similarity_resize is not None: 149 | in0_input = self.resize(in0_input, self.sample_similarity_resize) 150 | in1_input = self.resize(in1_input, self.sample_similarity_resize) 151 | 152 | outs0 = self.net.forward(in0_input) 153 | outs1 = self.net.forward(in1_input) 154 | 155 | feats0, feats1, diffs = {}, {}, {} 156 | 157 | for kk in range(self.L): 158 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 159 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 160 | 161 | res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)] 162 | val = sum(res) 163 | return val 164 | -------------------------------------------------------------------------------- /torch_fidelity/utils_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import torch 5 | 6 | from torch_fidelity.helpers import vprint 7 | 8 | 9 | def torch_maybe_compile(module, dummy_input, verbose): 10 | out = module 11 | try: 12 | compiled = torch.compile(module) 13 | try: 14 | compiled(dummy_input) 15 | vprint(verbose, "Feature extractor compiled") 16 | setattr(out, "forward_pure", out.forward) 17 | setattr(out, "forward", compiled) 18 | except Exception: 19 | vprint(verbose, "Feature extractor compiled, but failed to run. Falling back to pure torch") 20 | except Exception as e: 21 | vprint(verbose, "Feature extractor compilation failed. 
Falling back to pure torch") 22 | return out 23 | 24 | 25 | def torch_atomic_save(what, path): 26 | path = os.path.expanduser(path) 27 | path_dir = os.path.dirname(path) 28 | fp = tempfile.NamedTemporaryFile(delete=False, dir=path_dir) 29 | try: 30 | torch.save(what, fp) 31 | fp.close() 32 | os.rename(fp.name, path) 33 | finally: 34 | fp.close() 35 | if os.path.exists(fp.name): 36 | os.remove(fp.name) 37 | -------------------------------------------------------------------------------- /torch_fidelity/utils_torchvision.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from contextlib import redirect_stdout 4 | 5 | import torchvision 6 | 7 | from torch_fidelity.helpers import get_kwarg 8 | 9 | 10 | def torchvision_load_pretrained_vgg16(**kwargs): 11 | verbose = get_kwarg("verbose", kwargs) 12 | with redirect_stdout(sys.stderr), warnings.catch_warnings(): 13 | warnings.filterwarnings("ignore", message="The parameter 'pretrained' is deprecated") 14 | warnings.filterwarnings("ignore", message="Arguments other than a weight enum") 15 | warnings.filterwarnings( 16 | "ignore", 17 | message="'torch.load' received a zip file that looks like a TorchScript " 18 | "archive dispatching to 'torch.jit.load'", 19 | ) 20 | try: 21 | out = torchvision.models.vgg16( 22 | weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1, 23 | progress=verbose, 24 | ) 25 | except Exception: 26 | out = torchvision.models.vgg16( 27 | pretrained=True, 28 | progress=verbose, 29 | ) 30 | return out 31 | -------------------------------------------------------------------------------- /torch_fidelity/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.4.0-beta" 2 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | # if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import re 9 | import contextlib 10 | import numpy as np 11 | import torch 12 | import warnings 13 | import dnnlib 14 | 15 | #---------------------------------------------------------------------------- 16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 17 | # same constant is used multiple times. 
18 | 19 | _constant_cache = dict() 20 | 21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 22 | value = np.asarray(value) 23 | if shape is not None: 24 | shape = tuple(shape) 25 | if dtype is None: 26 | dtype = torch.get_default_dtype() 27 | if device is None: 28 | device = torch.device('cpu') 29 | if memory_format is None: 30 | memory_format = torch.contiguous_format 31 | 32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 33 | tensor = _constant_cache.get(key, None) 34 | if tensor is None: 35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 36 | if shape is not None: 37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 38 | tensor = tensor.contiguous(memory_format=memory_format) 39 | _constant_cache[key] = tensor 40 | return tensor 41 | 42 | #---------------------------------------------------------------------------- 43 | # Replace NaN/Inf with specified numerical values. 44 | 45 | try: 46 | nan_to_num = torch.nan_to_num # 1.8.0a0 47 | except AttributeError: 48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 49 | assert isinstance(input, torch.Tensor) 50 | if posinf is None: 51 | posinf = torch.finfo(input.dtype).max 52 | if neginf is None: 53 | neginf = torch.finfo(input.dtype).min 54 | assert nan == 0 55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Symbolic assert. 59 | 60 | try: 61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 62 | except AttributeError: 63 | symbolic_assert = torch.Assert # 1.7.0 64 | 65 | #---------------------------------------------------------------------------- 66 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 68 | 69 | @contextlib.contextmanager 70 | def suppress_tracer_warnings(): 71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 72 | warnings.filters.insert(0, flt) 73 | yield 74 | warnings.filters.remove(flt) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Assert that the shape of a tensor matches the given list of integers. 78 | # None indicates that the size of a dimension is allowed to vary. 79 | # Performs symbolic assertion when used in torch.jit.trace(). 
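# (Editor's annotation, not present in the original file.) Example:
# assert_shape(x, [None, 3, 256, 256]) accepts any batch size while pinning the channel and
# spatial dimensions; a mismatch raises AssertionError, or emits a symbolic assert when the
# sizes are traced tensors under torch.jit.trace().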
80 | 81 | def assert_shape(tensor, ref_shape): 82 | if tensor.ndim != len(ref_shape): 83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 85 | if ref_size is None: 86 | pass 87 | elif isinstance(ref_size, torch.Tensor): 88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 90 | elif isinstance(size, torch.Tensor): 91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 93 | elif size != ref_size: 94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 95 | 96 | #---------------------------------------------------------------------------- 97 | # Function decorator that calls torch.autograd.profiler.record_function(). 98 | 99 | def profiled_function(fn): 100 | def decorator(*args, **kwargs): 101 | with torch.autograd.profiler.record_function(fn.__name__): 102 | return fn(*args, **kwargs) 103 | decorator.__name__ = fn.__name__ 104 | return decorator 105 | 106 | #---------------------------------------------------------------------------- 107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 108 | # indefinitely, shuffling items as it goes. 109 | 110 | class InfiniteSampler(torch.utils.data.Sampler): 111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 112 | assert len(dataset) > 0 113 | assert num_replicas > 0 114 | assert 0 <= rank < num_replicas 115 | assert 0 <= window_size <= 1 116 | super().__init__(dataset) 117 | self.dataset = dataset 118 | self.rank = rank 119 | self.num_replicas = num_replicas 120 | self.shuffle = shuffle 121 | self.seed = seed 122 | self.window_size = window_size 123 | 124 | def __iter__(self): 125 | order = np.arange(len(self.dataset)) 126 | rnd = None 127 | window = 0 128 | if self.shuffle: 129 | rnd = np.random.RandomState(self.seed) 130 | rnd.shuffle(order) 131 | window = int(np.rint(order.size * self.window_size)) 132 | 133 | idx = 0 134 | while True: 135 | i = idx % order.size 136 | if idx % self.num_replicas == self.rank: 137 | yield order[i] 138 | if window >= 2: 139 | j = (i - rnd.randint(window)) % order.size 140 | order[i], order[j] = order[j], order[i] 141 | idx += 1 142 | 143 | #---------------------------------------------------------------------------- 144 | # Utilities for operating with torch.nn.Module parameters and buffers. 
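# (Editor's annotation, not present in the original file.) copy_params_and_buffers() matches
# tensors by name: with require_all=False, destination entries that have no counterpart in
# the source module are left unchanged, while require_all=True asserts that every destination
# tensor is found. ddp_sync() yields inside module.no_sync() when sync is False, which skips
# gradient synchronization for that forward/backward pass.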
145 | 146 | def params_and_buffers(module): 147 | assert isinstance(module, torch.nn.Module) 148 | return list(module.parameters()) + list(module.buffers()) 149 | 150 | def named_params_and_buffers(module): 151 | assert isinstance(module, torch.nn.Module) 152 | return list(module.named_parameters()) + list(module.named_buffers()) 153 | 154 | @torch.no_grad() 155 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 156 | assert isinstance(src_module, torch.nn.Module) 157 | assert isinstance(dst_module, torch.nn.Module) 158 | src_tensors = dict(named_params_and_buffers(src_module)) 159 | for name, tensor in named_params_and_buffers(dst_module): 160 | assert (name in src_tensors) or (not require_all) 161 | if name in src_tensors: 162 | tensor.copy_(src_tensors[name]) 163 | 164 | #---------------------------------------------------------------------------- 165 | # Context manager for easily enabling/disabling DistributedDataParallel 166 | # synchronization. 167 | 168 | @contextlib.contextmanager 169 | def ddp_sync(module, sync): 170 | assert isinstance(module, torch.nn.Module) 171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 172 | yield 173 | else: 174 | with module.no_sync(): 175 | yield 176 | 177 | #---------------------------------------------------------------------------- 178 | # Check DistributedDataParallel consistency across processes. 179 | 180 | def check_ddp_consistency(module, ignore_regex=None): 181 | assert isinstance(module, torch.nn.Module) 182 | for name, tensor in named_params_and_buffers(module): 183 | fullname = type(module).__name__ + '.' + name 184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 185 | continue 186 | tensor = tensor.detach() 187 | if tensor.is_floating_point(): 188 | tensor = nan_to_num(tensor) 189 | other = tensor.clone() 190 | torch.distributed.broadcast(tensor=other, src=0) 191 | assert (tensor == other).all(), fullname 192 | 193 | #---------------------------------------------------------------------------- 194 | # Print summary table of module hierarchy. 195 | 196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 197 | assert isinstance(module, torch.nn.Module) 198 | assert not isinstance(module, torch.jit.ScriptModule) 199 | assert isinstance(inputs, (tuple, list)) 200 | 201 | # Register hooks. 202 | entries = [] 203 | nesting = [0] 204 | def pre_hook(_mod, _inputs): 205 | nesting[0] += 1 206 | def post_hook(mod, _inputs, outputs): 207 | nesting[0] -= 1 208 | if nesting[0] <= max_nesting: 209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 214 | 215 | # Run module. 216 | outputs = module(*inputs) 217 | for hook in hooks: 218 | hook.remove() 219 | 220 | # Identify unique outputs, parameters, and buffers. 
221 | tensors_seen = set() 222 | for e in entries: 223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 227 | 228 | # Filter out redundant entries. 229 | if skip_redundant: 230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 231 | 232 | # Construct table. 233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 234 | rows += [['---'] * len(rows[0])] 235 | param_total = 0 236 | buffer_total = 0 237 | submodule_names = {mod: name for name, mod in module.named_modules()} 238 | for e in entries: 239 | name = '' if e.mod is module else submodule_names[e.mod] 240 | param_size = sum(t.numel() for t in e.unique_params) 241 | buffer_size = sum(t.numel() for t in e.unique_buffers) 242 | output_shapes = [str(list(t.shape)) for t in e.outputs] 243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 244 | rows += [[ 245 | name + (':0' if len(e.outputs) >= 2 else ''), 246 | str(param_size) if param_size else '-', 247 | str(buffer_size) if buffer_size else '-', 248 | (output_shapes + ['-'])[0], 249 | (output_dtypes + ['-'])[0], 250 | ]] 251 | for idx in range(1, len(e.outputs)): 252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 253 | param_total += param_size 254 | buffer_total += buffer_size 255 | rows += [['---'] * len(rows[0])] 256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 257 | 258 | # Print table. 259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 260 | print() 261 | for row in rows: 262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 263 | print() 264 | return outputs 265 | 266 | #---------------------------------------------------------------------------- 267 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 
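A brief usage sketch for two helpers from `torch_utils/misc.py` above, `copy_params_and_buffers()` and `print_module_summary()`, using two hypothetical, identically shaped networks (repository root assumed importable):

```python
import torch
from torch_utils import misc

# Two hypothetical networks with identical architecture.
src = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))
dst = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 4))

# Copy every parameter and buffer from src into dst, matching by name.
misc.copy_params_and_buffers(src, dst, require_all=True)
assert torch.equal(dst[0].weight, src[0].weight)

# Print a per-module table of parameter counts, buffer counts, and output shapes.
misc.print_module_summary(dst, [torch.randn(2, 16)])
```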
29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 
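To make the behaviour described in this docstring concrete, a minimal round trip might look like the sketch below. `TinyNet` is a hypothetical example class, and the decorated class must live in a real, importable module file; pickling from an interactive session fails because the module source cannot be retrieved.

```python
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self, width=8):
        super().__init__()
        self.fc = torch.nn.Linear(width, width)

net = TinyNet(width=8)
data = pickle.dumps(net)       # module source + state are embedded in the pickle
restored = pickle.loads(data)  # usable even if this file later changes
print(restored.init_kwargs)    # {'width': 8}, recorded automatically
```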
83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src: Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object.
170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /torch_utils/torch_dct.py: -------------------------------------------------------------------------------- 1 | """Taken from https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py 2 | Some modifications have been made to work with newer versions of Pytorch""" 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import time 8 | 9 | def re_arrange(img, img_size): 10 | b, c, h, w = img.shape 11 | x = img.reshape(b, c, h//img_size, img_size, w//img_size, img_size).permute(0, 2, 4, 1, 3, 5).reshape(-1, c, img_size, img_size) 12 | return x 13 | 14 | 15 | def recover(img, origin_img_size): 16 | img_size = img.shape[-1] 17 | x = img.reshape(-1, origin_img_size//img_size, origin_img_size//img_size, 3, img_size, img_size).permute(0, 3, 1, 4, 2, 5).reshape(-1, 3, origin_img_size, origin_img_size) 18 | return x 19 | 20 | 21 | def dct(x, norm=None): 22 | """ 23 | Discrete Cosine Transform, Type II (a.k.a. the DCT) 24 | For the meaning of the parameter `norm`, see: 25 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 26 | :param x: the input signal 27 | :param norm: the normalization, None or 'ortho' 28 | :return: the DCT-II of the signal over the last dimension 29 | """ 30 | x_shape = x.shape 31 | N = x_shape[-1] 32 | x = x.contiguous().view(-1, N) 33 | 34 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1) 35 | 36 | #Vc = torch.fft.rfft(v, 1) 37 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1)) 38 | 39 | k = - torch.arange(N, dtype=x.dtype, 40 | device=x.device)[None, :] * np.pi / (2 * N) 41 | W_r = torch.cos(k) 42 | W_i = torch.sin(k) 43 | 44 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i 45 | 46 | if norm == 'ortho': 47 | V[:, 0] /= np.sqrt(N) * 2 48 | V[:, 1:] /= np.sqrt(N / 2) * 2 49 | 50 | V = 2 * V.view(*x_shape) 51 | 52 | return V 53 | 54 | 55 | def idct(X, norm=None): 56 | """ 57 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III 58 | Our definition of idct is that idct(dct(x)) == x 59 | For the meaning of the parameter `norm`, see: 60 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 61 | :param X: the input signal 62 | :param norm: the normalization, None or 'ortho' 63 | :return: the inverse DCT-II of the signal over the last dimension 64 | """ 65 | 66 | x_shape = X.shape 67 | N = x_shape[-1] 68 | 69 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2 70 | 71 | if norm == 'ortho': 72 | X_v[:, 0] *= np.sqrt(N) * 2 73 | X_v[:, 1:] *= np.sqrt(N / 2) * 2 74 | 75 | k = torch.arange(x_shape[-1], dtype=X.dtype, 76 | device=X.device)[None, :] * np.pi / (2 * N) 77 | W_r = torch.cos(k) 78 | W_i = torch.sin(k) 79 | 80 | V_t_r = X_v 81 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1) 82 | 83 | V_r = V_t_r * W_r - V_t_i * W_i 84 | V_i = V_t_r * W_i + V_t_i * W_r 85 | 86 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2) 87 | 88 | #v = torch.fft.irfft(V, 1) 89 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1) 90 | x = v.new_zeros(v.shape) 91 | x[:, ::2] += v[:, :N - (N // 2)] 92 | x[:, 1::2] += v.flip([1])[:, :N // 2] 93 | 94 | return x.view(*x_shape) 95 | 96 | 97 | def dct_2d(x, size, norm=None): 98 | """ 99 | 2-dimentional Discrete Cosine Transform, Type II (a.k.a. 
the DCT) 100 | For the meaning of the parameter `norm`, see: 101 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 102 | :param x: the input signal 103 | :param norm: the normalization, None or 'ortho' 104 | :return: the DCT-II of the signal over the last 2 dimensions 105 | """ 106 | start = time.time() 107 | origin_size = x.shape[-1] 108 | if origin_size > size: 109 | x = re_arrange(x, size) 110 | # print('time1:', time.time()-start) 111 | X1 = dct(x, norm=norm) 112 | X2 = dct(X1.transpose(-1, -2), norm=norm) 113 | X2 = X2.transpose(-1, -2) 114 | # print('time2:', time.time()-start) 115 | if origin_size > size: 116 | X2 = recover(X2, origin_size) 117 | # print('time3:', time.time()-start) 118 | return X2 119 | 120 | def idct_2d(X, size, norm=None): 121 | """ 122 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III 123 | Our definition of idct is that idct_2d(dct_2d(x)) == x 124 | For the meaning of the parameter `norm`, see: 125 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 126 | :param X: the input signal 127 | :param norm: the normalization, None or 'ortho' 128 | :return: the inverse DCT-II of the signal over the last 2 dimensions 129 | """ 130 | start = time.time() 131 | origin_size = X.shape[-1] 132 | if origin_size > size: 133 | X = re_arrange(X, size) 134 | 135 | x1 = idct(X, norm=norm) 136 | x2 = idct(x1.transpose(-1, -2), norm=norm) 137 | x2 = x2.transpose(-1, -2) 138 | 139 | if origin_size > size: 140 | x2 = recover(x2, origin_size) 141 | # print('time4:', time.time()-start) 142 | return x2 143 | 144 | def dct_3d(x, norm=None): 145 | """ 146 | 3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT) 147 | For the meaning of the parameter `norm`, see: 148 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 149 | :param x: the input signal 150 | :param norm: the normalization, None or 'ortho' 151 | :return: the DCT-II of the signal over the last 3 dimensions 152 | """ 153 | X1 = dct(x, norm=norm) 154 | X2 = dct(X1.transpose(-1, -2), norm=norm) 155 | X3 = dct(X2.transpose(-1, -3), norm=norm) 156 | return X3.transpose(-1, -3).transpose(-1, -2) 157 | 158 | 159 | def idct_3d(X, norm=None): 160 | """ 161 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III 162 | Our definition of idct is that idct_3d(dct_3d(x)) == x 163 | For the meaning of the parameter `norm`, see: 164 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html 165 | :param X: the input signal 166 | :param norm: the normalization, None or 'ortho' 167 | :return: the inverse DCT-II of the signal over the last 3 dimensions 168 | """ 169 | x1 = idct(X, norm=norm) 170 | x2 = idct(x1.transpose(-1, -2), norm=norm) 171 | x3 = idct(x2.transpose(-1, -3), norm=norm) 172 | return x3.transpose(-1, -3).transpose(-1, -2) 173 | 174 | 175 | class LinearDCT(nn.Linear): 176 | """Implement any DCT as a linear layer; in practice this executes around 177 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will 178 | increase memory usage.
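A short sanity-check sketch for the block-wise 2D transforms above and for `LinearDCT`, whose definition follows; the tensor sizes are arbitrary illustrative choices and the repository root is assumed to be importable:

```python
import torch
from torch_utils import torch_dct

# Images larger than `size` are cut into size x size tiles (re_arrange),
# transformed tile by tile, and stitched back together (recover).
x = torch.randn(2, 3, 256, 256)
X = torch_dct.dct_2d(x, size=64, norm='ortho')
x_rec = torch_dct.idct_2d(X, size=64, norm='ortho')
print((x - x_rec).abs().max())  # near-zero: the transform round-trips

# The matrix-multiply formulation gives the same coefficients as dct_2d
# when no tiling is needed (input size == `size`).
y = torch.randn(4, 3, 64, 64)
layer = torch_dct.LinearDCT(64, 'dct', norm='ortho')
diff = torch_dct.apply_linear_2d(y, layer) - torch_dct.dct_2d(y, size=64, norm='ortho')
print(diff.abs().max())  # small float32 error
```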
179 | :param in_features: size of expected input 180 | :param type: which dct function in this file to use""" 181 | 182 | def __init__(self, in_features, type, norm=None, bias=False): 183 | self.type = type 184 | self.N = in_features 185 | self.norm = norm 186 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias) 187 | 188 | def reset_parameters(self): 189 | # initialise using dct function 190 | I = torch.eye(self.N) 191 | if self.type == 'dct': 192 | self.weight.data = dct(I, norm=self.norm).data.t() 193 | elif self.type == 'idct': 194 | self.weight.data = idct(I, norm=self.norm).data.t() 195 | self.weight.requires_grad = False # don't learn this! 196 | 197 | 198 | def apply_linear_2d(x, linear_layer): 199 | """Can be used with a LinearDCT layer to do a 2D DCT. 200 | :param x: the input signal 201 | :param linear_layer: any PyTorch Linear layer 202 | :return: result of linear layer applied to last 2 dimensions 203 | """ 204 | X1 = linear_layer(x) 205 | X2 = linear_layer(X1.transpose(-1, -2)) 206 | return X2.transpose(-1, -2) 207 | 208 | 209 | def apply_linear_3d(x, linear_layer): 210 | """Can be used with a LinearDCT layer to do a 3D DCT. 211 | :param x: the input signal 212 | :param linear_layer: any PyTorch Linear layer 213 | :return: result of linear layer applied to last 3 dimensions 214 | """ 215 | X1 = linear_layer(x) 216 | X2 = linear_layer(X1.transpose(-1, -2)) 217 | X3 = linear_layer(X2.transpose(-1, -3)) 218 | return X3.transpose(-1, -3).transpose(-1, -2) 219 | 220 | 221 | if __name__ == '__main__': 222 | x = torch.Tensor(1000, 4096) 223 | x.normal_(0, 1) 224 | linear_dct = LinearDCT(4096, 'dct') 225 | error = torch.abs(dct(x) - linear_dct(x)) 226 | assert error.max() < 1e-3, (error, error.max()) 227 | linear_idct = LinearDCT(4096, 'idct') 228 | error = torch.abs(idct(x) - linear_idct(x)) 229 | assert error.max() < 1e-3, (error, error.max()) 230 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for reporting and collecting training statistics across 9 | multiple processes and devices. The interface is designed to minimize 10 | synchronization overhead as well as the amount of boilerplate in user 11 | code.""" 12 | 13 | import re 14 | import numpy as np 15 | import torch 16 | import dnnlib 17 | 18 | from . import misc 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 24 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 25 | _rank = 0 # Rank of the current process. 26 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 27 | _sync_called = False # Has _sync() been called yet? 
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def init_multiprocessing(rank, sync_device): 34 | r"""Initializes `torch_utils.training_stats` for collecting statistics 35 | across multiple processes. 36 | 37 | This function must be called after 38 | `torch.distributed.init_process_group()` and before `Collector.update()`. 39 | The call is not necessary if multi-process collection is not needed. 40 | 41 | Args: 42 | rank: Rank of the current process. 43 | sync_device: PyTorch device to use for inter-process 44 | communication, or None to disable multi-process 45 | collection. Typically `torch.device('cuda', rank)`. 46 | """ 47 | global _rank, _sync_device 48 | assert not _sync_called 49 | _rank = rank 50 | _sync_device = sync_device 51 | 52 | #---------------------------------------------------------------------------- 53 | 54 | @misc.profiled_function 55 | def report(name, value): 56 | r"""Broadcasts the given set of scalars to all interested instances of 57 | `Collector`, across device and process boundaries. 58 | 59 | This function is expected to be extremely cheap and can be safely 60 | called from anywhere in the training loop, loss function, or inside a 61 | `torch.nn.Module`. 62 | 63 | Warning: The current implementation expects the set of unique names to 64 | be consistent across processes. Please make sure that `report()` is 65 | called at least once for each unique name by each process, and in the 66 | same order. If a given process has no scalars to broadcast, it can do 67 | `report(name, [])` (empty list). 68 | 69 | Args: 70 | name: Arbitrary string specifying the name of the statistic. 71 | Averages are accumulated separately for each unique name. 72 | value: Arbitrary set of scalars. Can be a list, tuple, 73 | NumPy array, PyTorch tensor, or Python scalar. 74 | 75 | Returns: 76 | The same `value` that was passed in. 77 | """ 78 | if name not in _counters: 79 | _counters[name] = dict() 80 | 81 | elems = torch.as_tensor(value) 82 | if elems.numel() == 0: 83 | return value 84 | 85 | elems = elems.detach().flatten().to(_reduce_dtype) 86 | moments = torch.stack([ 87 | torch.ones_like(elems).sum(), 88 | elems.sum(), 89 | elems.square().sum(), 90 | ]) 91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 92 | moments = moments.to(_counter_dtype) 93 | 94 | device = moments.device 95 | if device not in _counters[name]: 96 | _counters[name][device] = torch.zeros_like(moments) 97 | _counters[name][device].add_(moments) 98 | return value 99 | 100 | #---------------------------------------------------------------------------- 101 | 102 | def report0(name, value): 103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 104 | but ignores any scalars provided by the other processes. 105 | See `report()` for further details. 106 | """ 107 | report(name, value if _rank == 0 else []) 108 | return value 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | class Collector: 113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 114 | computes their long-term averages (mean and standard deviation) over 115 | user-defined periods of time. 
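In single-process use, `report()` and `Collector` fit together roughly as in the sketch below (the loss values are synthetic placeholders; repository root assumed importable):

```python
import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')

for step in range(100):
    loss = torch.rand([8])                      # stand-in for per-sample losses
    training_stats.report('Loss/total', loss)   # cheap, callable from anywhere

    if step % 10 == 9:
        collector.update()                      # pull counters into user-visible state
        print(f"step {step:3d}  "
              f"mean {collector.mean('Loss/total'):.3f}  "
              f"std {collector.std('Loss/total'):.3f}  "
              f"n {collector.num('Loss/total')}")
```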
116 | 117 | The averages are first collected into internal counters that are not 118 | directly visible to the user. They are then copied to the user-visible 119 | state as a result of calling `update()` and can then be queried using 120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 121 | internal counters for the next round, so that the user-visible state 122 | effectively reflects averages collected between the last two calls to 123 | `update()`. 124 | 125 | Args: 126 | regex: Regular expression defining which statistics to 127 | collect. The default is to collect everything. 128 | keep_previous: Whether to retain the previous averages if no 129 | scalars were collected on a given round 130 | (default: True). 131 | """ 132 | def __init__(self, regex='.*', keep_previous=True): 133 | self._regex = re.compile(regex) 134 | self._keep_previous = keep_previous 135 | self._cumulative = dict() 136 | self._moments = dict() 137 | self.update() 138 | self._moments.clear() 139 | 140 | def names(self): 141 | r"""Returns the names of all statistics broadcasted so far that 142 | match the regular expression specified at construction time. 143 | """ 144 | return [name for name in _counters if self._regex.fullmatch(name)] 145 | 146 | def update(self): 147 | r"""Copies current values of the internal counters to the 148 | user-visible state and resets them for the next round. 149 | 150 | If `keep_previous=True` was specified at construction time, the 151 | operation is skipped for statistics that have received no scalars 152 | since the last update, retaining their previous averages. 153 | 154 | This method performs a number of GPU-to-CPU transfers and one 155 | `torch.distributed.all_reduce()`. It is intended to be called 156 | periodically in the main training loop, typically once every 157 | N training steps. 158 | """ 159 | if not self._keep_previous: 160 | self._moments.clear() 161 | for name, cumulative in _sync(self.names()): 162 | if name not in self._cumulative: 163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 164 | delta = cumulative - self._cumulative[name] 165 | self._cumulative[name].copy_(cumulative) 166 | if float(delta[0]) != 0: 167 | self._moments[name] = delta 168 | 169 | def _get_delta(self, name): 170 | r"""Returns the raw moments that were accumulated for the given 171 | statistic between the last two calls to `update()`, or zero if 172 | no scalars were collected. 173 | """ 174 | assert self._regex.fullmatch(name) 175 | if name not in self._moments: 176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 177 | return self._moments[name] 178 | 179 | def num(self, name): 180 | r"""Returns the number of scalars that were accumulated for the given 181 | statistic between the last two calls to `update()`, or zero if 182 | no scalars were collected. 183 | """ 184 | delta = self._get_delta(name) 185 | return int(delta[0]) 186 | 187 | def mean(self, name): 188 | r"""Returns the mean of the scalars that were accumulated for the 189 | given statistic between the last two calls to `update()`, or NaN if 190 | no scalars were collected. 191 | """ 192 | delta = self._get_delta(name) 193 | if int(delta[0]) == 0: 194 | return float('nan') 195 | return float(delta[1] / delta[0]) 196 | 197 | def std(self, name): 198 | r"""Returns the standard deviation of the scalars that were 199 | accumulated for the given statistic between the last two calls to 200 | `update()`, or NaN if no scalars were collected. 
201 | """ 202 | delta = self._get_delta(name) 203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 204 | return float('nan') 205 | if int(delta[0]) == 1: 206 | return float(0) 207 | mean = float(delta[1] / delta[0]) 208 | raw_var = float(delta[2] / delta[0]) 209 | return np.sqrt(max(raw_var - np.square(mean), 0)) 210 | 211 | def as_dict(self): 212 | r"""Returns the averages accumulated between the last two calls to 213 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 214 | 215 | dnnlib.EasyDict( 216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 217 | ... 218 | ) 219 | """ 220 | stats = dnnlib.EasyDict() 221 | for name in self.names(): 222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 223 | return stats 224 | 225 | def __getitem__(self, name): 226 | r"""Convenience getter. 227 | `collector[name]` is a synonym for `collector.mean(name)`. 228 | """ 229 | return self.mean(name) 230 | 231 | #---------------------------------------------------------------------------- 232 | 233 | def _sync(names): 234 | r"""Synchronize the global cumulative counters across devices and 235 | processes. Called internally by `Collector.update()`. 236 | """ 237 | if len(names) == 0: 238 | return [] 239 | global _sync_called 240 | _sync_called = True 241 | 242 | # Collect deltas within current rank. 243 | deltas = [] 244 | device = _sync_device if _sync_device is not None else torch.device('cpu') 245 | for name in names: 246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 247 | for counter in _counters[name].values(): 248 | delta.add_(counter.to(device)) 249 | counter.copy_(torch.zeros_like(counter)) 250 | deltas.append(delta) 251 | deltas = torch.stack(deltas) 252 | 253 | # Sum deltas across ranks. 254 | if _sync_device is not None: 255 | torch.distributed.all_reduce(deltas) 256 | 257 | # Update cumulative values. 258 | deltas = deltas.cpu() 259 | for idx, name in enumerate(names): 260 | if name not in _cumulative: 261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 262 | _cumulative[name].add_(deltas[idx]) 263 | 264 | # Return name-value pairs. 265 | return [(name, _cumulative[name]) for name in names] 266 | 267 | #---------------------------------------------------------------------------- 268 | # Convenience. 269 | 270 | default_collector = Collector() 271 | 272 | #---------------------------------------------------------------------------- 273 | --------------------------------------------------------------------------------
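Finally, a hedged sketch of the multi-process setup described in `init_multiprocessing()`. The launcher, backend, and rank-to-device mapping are assumptions (e.g. `torchrun --nproc_per_node=N` with NCCL); without a distributed launch, the fallback path keeps purely local collection.

```python
import os
import torch
from torch_utils import training_stats

def setup_training_stats():
    # Under torchrun, RANK/WORLD_SIZE are provided via the environment.
    if 'RANK' in os.environ and torch.cuda.is_available():
        torch.distributed.init_process_group(backend='nccl')
        rank = torch.distributed.get_rank()
        device = torch.device('cuda', rank % torch.cuda.device_count())
        torch.cuda.set_device(device)
        training_stats.init_multiprocessing(rank=rank, sync_device=device)
        return rank
    return 0  # single process: collection works locally without any setup

rank = setup_training_stats()
training_stats.report0('Progress/tick', [rank])  # only rank 0 contributes this statistic
```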