├── .gitignore
├── README.md
├── app.py
├── assets
│   ├── beauty_makeup.txt
│   ├── gradio.png
│   ├── mt_makeup.txt
│   └── nomakeup.txt
├── config.py
├── configs
│   ├── base.yaml
│   ├── makeup.yaml
│   ├── model_256_256.yaml
│   ├── model_256_256_eval_only.yaml
│   ├── model_64_64.yaml
│   └── non_makeup.yaml
├── data
│   └── mt_text_anno.json
├── dataset
│   ├── __init__.py
│   ├── makeup_dataset.py
│   ├── prefetcher.py
│   └── text_dataset.py
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── docs
│   ├── index.html
│   └── static
│       ├── css
│       │   ├── bulma-carousel.min.css
│       │   ├── bulma-slider.min.css
│       │   ├── bulma.css.map.txt
│       │   ├── bulma.min.css
│       │   ├── fontawesome.all.min.css
│       │   └── index.css
│       ├── images
│       │   ├── all_in_one_vis.svg
│       │   ├── makeup_to_non.svg
│       │   ├── non_to_makeup.svg
│       │   ├── pipeline.svg
│       │   ├── scale_comp_combine.svg
│       │   ├── teaser.svg
│       │   ├── text-example.svg
│       │   └── text_modification.svg
│       ├── js
│       │   ├── bulma-carousel.js
│       │   ├── bulma-carousel.min.js
│       │   ├── bulma-slider.js
│       │   ├── bulma-slider.min.js
│       │   ├── fontawesome.all.min.js
│       │   └── index.js
│       └── poster.pdf
├── generate_text_editing.py
├── generate_transfer.py
├── generate_translation.py
├── main.py
├── misc
│   ├── __init__.py
│   ├── compute_classification.py
│   ├── compute_metrics.py
│   ├── compute_removal_metrics.py
│   ├── compute_text_metrics.py
│   ├── constant.py
│   ├── convert_beauty_face.py
│   ├── meter.py
│   └── morphing.py
├── modeling
│   ├── __init__.py
│   ├── generate.py
│   ├── loss.py
│   ├── model.py
│   ├── scheduler.py
│   ├── text_translation.py
│   └── translation.py
├── requirements.txt
├── script
│   └── train_text_to_image.sh
├── sd_training
│   └── train_text_to_image.py
├── torch_fidelity
│   ├── __init__.py
│   ├── datasets.py
│   ├── defaults.py
│   ├── deprecations.py
│   ├── feature_extractor_base.py
│   ├── feature_extractor_clip.py
│   ├── feature_extractor_dinov2.py
│   ├── feature_extractor_inceptionv3.py
│   ├── feature_extractor_vgg16.py
│   ├── fidelity.py
│   ├── generative_model_base.py
│   ├── generative_model_modulewrapper.py
│   ├── generative_model_onnx.py
│   ├── helpers.py
│   ├── interpolate_compat_tensorflow.py
│   ├── metric_fid.py
│   ├── metric_isc.py
│   ├── metric_kid.py
│   ├── metric_ppl.py
│   ├── metric_prc.py
│   ├── metrics.py
│   ├── noise.py
│   ├── registry.py
│   ├── sample_similarity_base.py
│   ├── sample_similarity_lpips.py
│   ├── utils.py
│   ├── utils_torch.py
│   ├── utils_torchvision.py
│   └── version.py
└── torch_utils
    ├── __init__.py
    ├── distributed.py
    ├── misc.py
    ├── persistence.py
    ├── torch_dct.py
    └── training_stats.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .huskyrc.json
3 | out
4 | log.log
5 | **/node_modules
6 | *.pyc
7 | *.vsix
8 | envVars.txt
9 | **/.vscode/.ropeproject/**
10 | **/testFiles/**/.cache/**
11 | *.noseids
12 | .nyc_output
13 | .vscode-test
14 | __pycache__
15 | npm-debug.log
16 | **/.mypy_cache/**
17 | !yarn.lock
18 | coverage/
19 | cucumber-report.json
20 | **/.vscode-test/**
21 | **/.vscode test/**
22 | **/.vscode-smoke/**
23 | **/.venv*/
24 | port.txt
25 | precommit.hook
26 | pythonFiles/lib/**
27 | pythonFiles/get-pip.py
28 | debug_coverage*/**
29 | languageServer/**
30 | languageServer.*/**
31 | bin/**
32 | obj/**
33 | .pytest_cache
34 | tmp/**
35 | .python-version
36 | .vs/
37 | test-results*.xml
38 | xunit-test-results.xml
39 | build/ci/performance/performance-results.json
40 | !build/
41 | debug*.log
42 | debugpy*.log
43 | pydevd*.log
44 | nodeLanguageServer/**
45 | nodeLanguageServer.*/**
46 | dist/**
47 | # translation files
48 | *.xlf
49 | package.nls.*.json
50 | l10n/
51 | runs/
52 | generate/
53 | wandb/
54 | .env
55 | *.ipynb
56 | generate_outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # MAD: Makeup All-in-One with Cross-Domain Diffusion Model
4 |
5 | A unified cross-domain diffusion model for various makeup tasks
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | > Bo-Kai Ruan, Hong-Han Shuai
17 | >
18 | > * Contact: Bo-Kai Ruan
19 | > * [arXiv paper](https://arxiv.org/abs/2504.02545) | [Project Website](https://basiclab.github.io/MAD)
20 |
21 | ## 🚀 A. Installation
22 |
23 | ### Step 1: Create Environment
24 |
25 | * Ubuntu 22.04 with Python ≥ 3.10 (tested with GPU using CUDA 11.8)
26 |
27 | ```shell
28 | conda create --name mad python=3.10 -y
29 | conda activate mad
30 | ```
31 |
32 | ### Step 2: Install Dependencies
33 |
34 | ```shell
35 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia -y
36 | conda install xformers -c xformers -y
37 | pip install -r requirements.txt
38 |
39 | # Weights for landmarks
40 | wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
41 | bzip2 -d shape_predictor_68_face_landmarks.dat.bz2 && mkdir weights && mv shape_predictor_68_face_landmarks.dat weights
42 | ```
43 |
44 | ### Step 3: Prepare the Dataset
45 |
46 | The following table provides download links for the datasets:
47 |
48 | | Dataset | Link |
49 | | ------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
50 | | MT Dataset | [all](https://drive.google.com/file/d/18UlvYDL6UGZ2rs0yaDsSzoUlw8KI5ABY/view) |
51 | | BeautyFace Dataset | [images](https://drive.google.com/file/d/1mhoopmi7OlsClOuKocjldGbTYnyDzNMc/view?usp=sharing), [parsing map](https://drive.google.com/file/d/1WgadvcV1pUtEMCYxjwWBledEQfDbadn7/view?usp=sharing) |
52 |
53 | We recommend unzipping and placing the datasets in the same folder with the following structure:
54 |
55 | ```plaintext
56 | 📦 data
57 | ┣ 📂 mtdataset
58 | ┃ ┣ 📂 images
59 | ┃ ┃ ┣ 📂 makeup
60 | ┃ ┃ ┗ 📂 non-makeup
61 | ┃ ┣ 📂 parsing
62 | ┃ ┃ ┣ 📂 makeup
63 | ┃ ┃ ┗ 📂 non-makeup
64 | ┣ 📂 beautyface
65 | ┃ ┣ 📂 images
66 | ┃ ┗ 📂 parsing
67 | ┗ ...
68 | ```
69 |
70 | Run `misc/convert_beauty_face.py` to convert the parsing maps for the BeautyFace dataset:
71 |
72 | ```shell
73 | python misc/convert_beauty_face.py --original data/beautyface/parsing --output data/beautyface/parsing
74 | ```
75 |
76 | We also provide the text annotations [here](data/mt_text_anno.json).
77 |
78 | ## 📦 B. Usage
79 |
80 | The pretrained weights are available on [Hugging Face](https://huggingface.co/Justin900/MAD) 🤗.
81 |
82 | ### B.1 Training a Model
83 |
84 | * With our model
85 |
86 | ```shell
87 | # Single GPU
88 | python main.py --config configs/model_256_256.yaml
89 |
90 | # Multi-GPU
91 | accelerate launch --multi_gpu --num_processes={NUM_OF_GPU} main.py --config configs/model_256_256.yaml
92 | ```
93 |
94 | * With Stable Diffusion
95 |
96 | ```shell
97 | ./script/train_text_to_image.sh
98 | ```
99 |
100 | ### B.2 Beauty Filter or Makeup Removal
101 |
102 | To use the beauty filter or perform makeup removal, create a `.txt` file listing the images. Here's an example:
103 |
104 | ```plaintext
105 | makeup/xxxx1.jpg
106 | makeup/xxxx2.jpg
107 | ```
108 |
109 | Use the `source-label` and `target-label` arguments to choose between beauty filtering and makeup removal: `0` denotes makeup images and `1` denotes non-makeup images.
110 |
111 | For makeup removal:
112 |
113 | ```shell
114 | python generate_translation.py \
115 | --config configs/model_256_256.yaml \
116 | --save-folder removal_results \
117 | --source-root data/mtdataset/images \
118 | --source-list assets/mt_makeup.txt \
119 | --source-label 0 \
120 | --target-label 1 \
121 | --num-process {NUM_PROCESS} \
122 | --opts MODEL.PRETRAINED Justin900/MAD
123 | ```
124 |
125 | ### B.3 Makeup Transfer
126 |
127 | For makeup transfer, prepare two `.txt` files: one for source images and one for target images. Example:
128 |
129 | ```plaintext
130 | # File 1 | # File 2
131 | makeup/xxxx1.jpg | non-makeup/xxxx1.jpg
132 | makeup/xxxx2.jpg | non-makeup/xxxx2.jpg
133 | ... | ...
134 | ```
135 |
136 | To apply makeup transfer:
137 |
138 | ```shell
139 | python generate_transfer.py \
140 | --config configs/model_256_256.yaml \
141 | --save-folder transfer_result \
142 | --source-root data/mtdataset/images \
143 | --target-root data/beautyface/images \
144 | --source-list assets/nomakeup.txt \
145 | --target-list assets/beauty_makeup.txt \
146 | --source-label 1 \
147 | --target-label 0 \
148 | --num-process {NUM_PROCESS} \
149 | --inpainting \
150 | --cam \
151 | --opts MODEL.PRETRAINED Justin900/MAD
152 | ```
153 |
154 | ### B.4 Text Modification
155 |
156 | For text modification, prepare a JSON file:
157 |
158 | ```json
159 | [
160 |   {"image": "xxx.jpg", "style": "makeup with xxx"},
161 |   ...
162 | ]
163 | ```
164 |
165 | ```shell
166 | python generate_text_editing.py \
167 | --save-folder text_editing_results \
168 | --source-root data/mtdataset/images \
169 | --source-list assets/text_editing.json \
170 | --num-process {NUM_PROCESS} \
171 | --model-path Justin900/MAD
172 | ```
173 |
174 | ## 🎨 C. Web UI
175 |
176 | Start the web UI, built with [gradio](https://github.com/gradio-app/gradio), and access it at `localhost:7860`:
177 |
178 | ```shell
179 | python app.py
180 | ```
181 |
182 | 
183 |
184 | ## Citation
185 |
186 | ```bibtex
187 | @article{ruan2025mad,
188 | title={MAD: Makeup All-in-One with Cross-Domain Diffusion Model},
189 | author={Ruan, Bo-Kai and Shuai, Hong-Han},
190 | journal={arXiv preprint arXiv:2504.02545},
191 | year={2025}
192 | }
193 | ```
194 |
--------------------------------------------------------------------------------
/assets/beauty_makeup.txt:
--------------------------------------------------------------------------------
1 | 3_49.png
2 | 709_20_09.png
3 |
--------------------------------------------------------------------------------
/assets/gradio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiclab/MAD/f8b0b012eb3886274df41da9dee0eed6e07e3a34/assets/gradio.png
--------------------------------------------------------------------------------
/assets/mt_makeup.txt:
--------------------------------------------------------------------------------
1 | makeup/vHX44.png
2 | makeup/vFG436.png
3 |
--------------------------------------------------------------------------------
/assets/nomakeup.txt:
--------------------------------------------------------------------------------
1 | non-makeup/xfsy_0458.png
2 | non-makeup/vSYYZ388.png
3 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pprint
3 |
4 | from colorama import Fore, Style
5 | from tabulate import tabulate
6 | from yacs.config import CfgNode as CN
7 |
8 |
9 | def create_cfg():
10 | cfg = CN()
11 | cfg._BASE_ = None
12 | cfg.PROJECT_NAME = "Makeup Transfer with Diffusion"
13 | cfg.PROJECT_DIR = None
14 |
15 | # ##### Model setup #####
16 | cfg.MODEL = CN()
17 | cfg.MODEL.IN_CHANNELS = 3
18 | cfg.MODEL.OUT_CHANNELS = cfg.MODEL.IN_CHANNELS
19 | cfg.MODEL.LAYERS_PER_BLOCK = 2
20 | cfg.MODEL.BASE_DIM = 128
21 | cfg.MODEL.LAYER_SCALE = [1, 1, 2, 2, 4, 4]
22 | cfg.MODEL.PRETRAINED = None # "pretrained/pretrained.pkl"
23 | cfg.MODEL.LABEL_DIM = 0
24 | cfg.MODEL.DOWN_BLOCK_TYPE = (["CrossAttnDownBlock2D"] * (len(cfg.MODEL.LAYER_SCALE) - 1)) + [
25 | "DownBlock2D"
26 | ]
27 | cfg.MODEL.UP_BLOCK_TYPE = (["CrossAttnUpBlock2D"] * (len(cfg.MODEL.LAYER_SCALE) - 1)) + [
28 | "UpBlock2D"
29 | ]
30 |
31 | # ###### Training set ######
32 | cfg.TRAIN = CN()
33 | cfg.TRAIN.MAKEUP = None
34 |
35 | # Log and save
36 | cfg.TRAIN.RESUME = None
37 | cfg.TRAIN.IMAGE_SIZE = 256
38 | cfg.TRAIN.LOG_INTERVAL = 20
39 | cfg.TRAIN.SAVE_INTERVAL = 10000
40 | cfg.TRAIN.SAMPLE_INTERVAL = 10000
41 | cfg.TRAIN.ROOT = None
42 | cfg.TRAIN.TEXT_LABEL_PATH = None
43 |
44 | # Training iteration
45 | cfg.TRAIN.BATCH_SIZE = 2
46 | cfg.TRAIN.NUM_WORKERS = 4
47 | cfg.TRAIN.MAX_ITER = 350000
48 | cfg.TRAIN.GRADIENT_ACCUMULATION_STEPS = 16
49 | cfg.TRAIN.MIXED_PRECISION = "no"
50 | cfg.TRAIN.GRAD_NORM = 1.0
51 |
52 | # EMA setup
53 | cfg.TRAIN.EMA_MAX_DECAY = 0.9999
54 | cfg.TRAIN.EMA_INV_GAMMA = 1.0
55 | cfg.TRAIN.EMA_POWER = 0.75
56 |
57 | # Optimizer
58 | cfg.TRAIN.LR = 0.0001
59 | cfg.TRAIN.LR_WARMUP = 1000
60 |
61 | # Diffusion setup
62 | cfg.TRAIN.TIME_STEPS = 1000
63 | cfg.TRAIN.SAMPLE_STEPS = cfg.TRAIN.TIME_STEPS
64 | cfg.TRAIN.NOISE_SCHEDULER = CN()
65 | # ///// for linear start \\\\\\\
66 | cfg.TRAIN.NOISE_SCHEDULER.BETA_START = 1e-4
67 | cfg.TRAIN.NOISE_SCHEDULER.BETA_END = 0.02
68 | # ///// for linear end \\\\\\\
69 | cfg.TRAIN.NOISE_SCHEDULER.TYPE = "linear"
70 | cfg.TRAIN.NOISE_SCHEDULER.PRED_TYPE = "epsilon"
71 |
72 | # ======= Evaluation set =======
73 | cfg.EVAL = CN()
74 | cfg.EVAL.BATCH_SIZE = 4
75 | cfg.EVAL.SAMPLE_STEPS = 1000
76 | cfg.EVAL.ETA = 0.01
77 | cfg.EVAL.REFINE_STEPS = 0
78 | cfg.EVAL.REFINE_ITERATIONS = 0
79 | cfg.EVAL.SCHEDULER = "ddpm"
80 |
81 | return cfg
82 |
83 |
84 | def merge_possible_with_base(cfg: CN, config_path):
85 | with open(config_path, "r") as f:
86 | new_cfg = cfg.load_cfg(f)
87 | if "_BASE_" in new_cfg:
88 | cfg.merge_from_file(osp.join(osp.dirname(config_path), new_cfg._BASE_))
89 | cfg.merge_from_other_cfg(new_cfg)
90 |
91 |
92 | def split_into(v):
93 | res = "(\n"
94 | for item in v:
95 | res += f" {item},\n"
96 | res += ")"
97 | return res
98 |
99 |
100 | def pretty_print_cfg(cfg):
101 | def _indent(s_, num_spaces):
102 | s = s_.split("\n")
103 | if len(s) == 1:
104 | return s_
105 | first = s.pop(0)
106 | s = [(num_spaces * " ") + line for line in s]
107 | s = "\n".join(s)
108 | s = first + "\n" + s
109 | return s
110 |
111 | r = ""
112 | s = []
113 | for k, v in sorted(cfg.items()):
114 | seperator = "\n" if isinstance(v, CN) else " "
115 | attr_str = "{}:{}{}".format(
116 | str(k),
117 | seperator,
118 | pretty_print_cfg(v) if isinstance(v, CN) else pprint.pformat(v),
119 | )
120 | attr_str = _indent(attr_str, 2)
121 | s.append(attr_str)
122 | r += "\n".join(s)
123 | return r
124 |
125 |
126 | def show_config(cfg):
127 | table = tabulate(
128 | {"Configuration": [pretty_print_cfg(cfg)]}, headers="keys", tablefmt="fancy_grid"
129 | )
130 | print(f"{Fore.BLUE}", end="")
131 | print(table)
132 | print(f"{Style.RESET_ALL}", end="")
133 |
--------------------------------------------------------------------------------
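
For orientation, here is a minimal sketch of how these helpers are combined by the generation scripts below (run from the repo root; the override list mirrors the `--opts` flag, and the checkpoint id is the one from the README):

```python
from config import create_cfg, merge_possible_with_base, show_config

cfg = create_cfg()                                           # defaults defined above
merge_possible_with_base(cfg, "configs/model_256_256.yaml")  # also pulls in base.yaml via _BASE_
cfg.merge_from_list(["MODEL.PRETRAINED", "Justin900/MAD"])   # same mechanism as `--opts KEY VALUE`
show_config(cfg)                                             # tabulated, colorized dump
```
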
/configs/base.yaml:
--------------------------------------------------------------------------------
1 | TRAIN:
2 | ROOT: ["data/mtdataset/images"]
3 | TEXT_LABEL_PATH: "data/mt_text_anno.json"
4 | MIXED_PRECISION: "fp16"
--------------------------------------------------------------------------------
/configs/makeup.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: base.yaml
2 | PROJECT_DIR: runs/makeup
3 | TRAIN:
4 | MAKEUP: true
--------------------------------------------------------------------------------
/configs/model_256_256.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: base.yaml
2 | PROJECT_DIR: runs/mixup_256_256
3 | MODEL:
4 | LABEL_DIM: 2
5 | TRAIN:
6 | BATCH_SIZE: 4
7 | GRADIENT_ACCUMULATION_STEPS: 8
8 | MAX_ITER: 700000
--------------------------------------------------------------------------------
/configs/model_256_256_eval_only.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: base.yaml
2 | PROJECT_DIR: runs/for_eval_only
3 | MODEL:
4 | LABEL_DIM: 2
5 | TRAIN:
6 | ROOT: ["data/mtdataset/images", "data/beautyface/images"]
7 | BATCH_SIZE: 4
8 | GRADIENT_ACCUMULATION_STEPS: 8
9 | MAX_ITER: 350000
--------------------------------------------------------------------------------
/configs/model_64_64.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: base.yaml
2 | PROJECT_DIR: runs/mixup_64_64
3 | TRAIN:
4 | IMAGE_SIZE: 64
5 | BATCH_SIZE: 16
6 | GRADIENT_ACCUMULATION_STEPS: 4
7 | MAX_ITER: 700000
8 | MODEL:
9 | BASE_DIM: 192
10 | LAYER_SCALE: [1, 2, 3, 4]
11 | LAYERS_PER_BLOCK: 3
12 | LABEL_DIM: 2
13 |
--------------------------------------------------------------------------------
/configs/non_makeup.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: base.yaml
2 | PROJECT_DIR: runs/non_makeup
3 | TRAIN:
4 | MAKEUP: false
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .makeup_dataset import get_makeup_loader
2 | from .prefetcher import DataPrefetcher
3 |
4 | __all__ = ["DataPrefetcher", "get_makeup_loader"]
5 |
--------------------------------------------------------------------------------
/dataset/makeup_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import os.path as osp
5 | import random
6 |
7 | import torch
8 | from PIL import Image, ImageFile
9 | from sklearn.model_selection import train_test_split
10 | from transformers import CLIPTokenizer
11 |
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 | COMP_TO_TEXT = {
15 | "face": "skin",
16 | "eye": "eyes",
17 | }
18 |
19 | CHECK_TEXT = {
20 | "face": "skin",
21 | "lips": "lip",
22 | "eyes": "eye",
23 | }
24 |
25 |
26 | class MakeupDataset:
27 | def __init__(
28 | self,
29 | root_list,
30 | text_label_path,
31 | use_label=False,
32 | train=True,
33 | makeup=False,
34 | transforms=None,
35 | ):
36 | # Important!!!!!!!!!!
37 | # Adding beautyface or wild to `root_list` is only allowed for training a removal model for evaluation.
38 | # The removal model should not be used to generate the result, which may lead to unfair comparison.
39 |
40 | if not isinstance(root_list, list):
41 | root_list = [root_list]
42 |
43 | self.img_list = []
44 | self.use_label = use_label
45 | self.label_to_idx = {}
46 |
47 | with open(text_label_path, "r") as f:
48 | self.text_label = json.load(f)
49 |
50 | self.tokenizer = CLIPTokenizer.from_pretrained(
51 | "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
52 | )
53 |
54 | if use_label:
55 | for root in root_list:
56 | for catego in ["", "makeup", "non-makeup"]:
57 | if not osp.exists(osp.join(root, catego)):
58 | continue
59 | self.img_list.extend(
60 | list(glob.glob(osp.join(root, catego, "*.png")))
61 | + list(glob.glob(osp.join(root, catego, "*.jpg")))
62 | )
63 | # 0: makeup, 1: non-makeup
64 | for img in self.img_list:
65 | folder_name = osp.basename(osp.dirname(img))
66 | if folder_name == "non-makeup":
67 | self.label_to_idx[folder_name] = 1
68 | else:
69 | self.label_to_idx[folder_name] = 0
70 | else:
71 | for root in root_list:
72 | img_dir = "makeup" if makeup else "non-makeup"
73 | if not osp.exists(osp.join(root, img_dir)):
74 | continue
75 | self.img_list.extend(
76 | list(glob.glob(osp.join(root, img_dir, "*.png")))
77 | + list(glob.glob(osp.join(root, img_dir, "*.jpg")))
78 | )
79 | self.img_list.sort() # ensure the order is the same on different machines
80 |
81 | # Create label list
82 | train_split, val_split = train_test_split(self.img_list, test_size=0.1, random_state=42)
83 | self.img_list = train_split if train else val_split
84 | self.transforms = transforms
85 |
86 | def __len__(self):
87 | return len(self.img_list)
88 |
89 | def __getitem__(self, idx):
90 | img_name = self.img_list[idx]
91 | img = Image.open(img_name).convert("RGB")
92 | if self.transforms:
93 | img = self.transforms(img)
94 |
95 | folder_name = osp.basename(osp.dirname(img_name))
96 | label = self.label_to_idx[folder_name] if self.use_label else None
97 |
98 | if random.random() > 0.7:
99 | if os.path.basename(img_name) in self.text_label:
100 | all_desps = self.text_label[os.path.basename(img_name)]
101 | desp_list = []
102 | for comp_name, desp in all_desps.items():
103 | out = random.choice(desp).strip().lower()
104 | if CHECK_TEXT.get(comp_name, comp_name) not in out:
105 | out = f"{out} {COMP_TO_TEXT.get(comp_name, comp_name)}"
106 | desp_list.append(out)
107 | random.shuffle(desp_list)
108 | desp = "makeup with " + ", ".join(desp_list)
109 | else:
110 | desp = "no or light makeup"
111 | else:
112 | desp = ""
113 |
114 | text_inputs = self.tokenizer(
115 | desp,
116 | padding="max_length",
117 | max_length=self.tokenizer.model_max_length,
118 | truncation=True,
119 | return_tensors="pt",
120 | )
121 | text_input_ids = text_inputs.input_ids
122 | return {
123 | "image": img,
124 | "text": text_input_ids,
125 | "label": (
126 | torch.nn.functional.one_hot(torch.LongTensor([label]), len(self.label_to_idx))
127 | .squeeze(0)
128 | .float()
129 | if self.use_label
130 | else None
131 | ),
132 | }
133 |
134 |
135 | def get_makeup_loader(cfg, train, transforms):
136 | dataset = MakeupDataset(
137 | cfg.TRAIN.ROOT,
138 | cfg.TRAIN.TEXT_LABEL_PATH,
139 | use_label=cfg.MODEL.LABEL_DIM > 0,
140 | train=train,
141 | makeup=cfg.TRAIN.MAKEUP,
142 | transforms=transforms,
143 | )
144 | return torch.utils.data.DataLoader(
145 | dataset,
146 | shuffle=train,
147 | batch_size=cfg.TRAIN.BATCH_SIZE,
148 | num_workers=cfg.TRAIN.NUM_WORKERS,
149 | pin_memory=True,
150 | drop_last=True,
151 | )
152 |
--------------------------------------------------------------------------------
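
To make the text-prompt construction in `__getitem__` concrete, here is a small standalone sketch of the same logic using a made-up annotation entry (real entries are keyed by image file name in `data/mt_text_anno.json`):

```python
import random

COMP_TO_TEXT = {"face": "skin", "eye": "eyes"}
CHECK_TEXT = {"face": "skin", "lips": "lip", "eyes": "eye"}

# Hypothetical annotation entry for one image
all_desps = {"face": ["Fair"], "lips": ["red lipstick"]}

desp_list = []
for comp_name, desp in all_desps.items():
    out = random.choice(desp).strip().lower()
    # Append the component name unless the description already mentions it
    if CHECK_TEXT.get(comp_name, comp_name) not in out:
        out = f"{out} {COMP_TO_TEXT.get(comp_name, comp_name)}"
    desp_list.append(out)
random.shuffle(desp_list)
print("makeup with " + ", ".join(desp_list))  # e.g. "makeup with red lipstick, fair skin"
```
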
/dataset/prefetcher.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class DataPrefetcher:
5 | def __init__(self, loader, device):
6 | self.origin_loader = loader
7 | self.loader = iter(loader)
8 |         self.use_gpu = torch.cuda.is_available()
9 |         self.device = device  # must be set before the stream below is created
10 |         self.stream = torch.cuda.Stream(device=device) if self.use_gpu else None
11 | self.preload()
12 |
13 | def preload(self):
14 | try:
15 | self.batch = next(self.loader)
16 | except StopIteration:
17 | self.loader = iter(self.origin_loader)
18 | self.batch = next(self.loader)
19 | if self.use_gpu:
20 | with torch.cuda.stream(self.stream):
21 | self.batch = self.batch.to(device=self.device, non_blocking=True)
22 |
23 | def next(self):
24 | if self.use_gpu:
25 | torch.cuda.current_stream().wait_stream(self.stream)
26 | batch = self.batch
27 | self.preload()
28 | return batch
29 |
--------------------------------------------------------------------------------
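
A minimal usage sketch, assuming a CUDA device and a loader whose batches expose `.to(...)` (a plain tensor batch here; the dict batches from `MakeupDataset` would need to be moved field by field):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from dataset import DataPrefetcher

# Toy loader that yields a single stacked tensor per batch
data = TensorDataset(torch.randn(64, 3, 8, 8))
loader = DataLoader(data, batch_size=8, collate_fn=lambda b: torch.stack([x[0] for x in b]))

prefetcher = DataPrefetcher(loader, device="cuda:0")
for _ in range(len(loader)):
    batch = prefetcher.next()  # current batch on the GPU; the next copy overlaps with compute
```
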
/dataset/text_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import os
4 | import random
5 |
6 | import torch
7 | from PIL import Image
8 |
9 | COMP_TO_TEXT = {
10 | "face": "skin",
11 | "eye": "eyes",
12 | }
13 |
14 | CHECK_TEXT = {
15 | "face": "skin",
16 | "lips": "lip",
17 | "eyes": "eye",
18 | }
19 |
20 |
21 | class MTTextDataset(torch.utils.data.Dataset):
22 | def __init__(self, root, json_file, tokenizer, train=True, transforms=None):
23 | self.data = (
24 | list(glob.glob(os.path.join(root, "makeup/*.png")))
25 | + list(glob.glob(os.path.join(root, "makeup/*.jpg")))
26 | + list(glob.glob(os.path.join(root, "no_makeup/*.png")))
27 | + list(glob.glob(os.path.join(root, "no_makeup/*.jpg")))
28 | )
29 | self.train = train
30 | self.transforms = transforms
31 | self.tokenizer = tokenizer
32 | with open(json_file, "r") as f:
33 | self.desp = json.load(f)
34 |
35 | def __len__(self):
36 | return len(self.data)
37 |
38 | def __getitem__(self, idx):
39 | img_name = self.data[idx]
40 | img = Image.open(img_name).convert("RGB")
41 | img = self.transforms(img)
42 |
43 | if os.path.basename(img_name) in self.desp:
44 | all_desps = self.desp[os.path.basename(img_name)]
45 | desp_list = []
46 | for comp_name, desp in all_desps.items():
47 | out = random.choice(desp).strip().lower()
48 | if CHECK_TEXT.get(comp_name, comp_name) not in out:
49 | out = f"{out} {COMP_TO_TEXT.get(comp_name, comp_name)}"
50 | desp_list.append(out)
51 | random.shuffle(desp_list)
52 | desp = "makeup with " + ", ".join(desp_list)
53 | else:
54 | desp = "no or light makeup"
55 |
56 | desp = self.tokenizer(
57 | desp,
58 | max_length=self.tokenizer.model_max_length,
59 | padding="max_length",
60 | truncation=True,
61 | return_tensors="pt",
62 | ).input_ids
63 |
64 | return {"pixel_values": img, "input_ids": desp}
65 |
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import EasyDict, make_cache_dir_path
2 |
3 | __all__ = ["EasyDict", "make_cache_dir_path"]
4 |
--------------------------------------------------------------------------------
/docs/static/css/bulma-carousel.min.css:
--------------------------------------------------------------------------------
1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center 
center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10}
--------------------------------------------------------------------------------
/docs/static/css/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Noto Sans', sans-serif;
3 | }
4 |
5 |
6 | .footer .icon-link {
7 | font-size: 25px;
8 | color: #000;
9 | }
10 |
11 | .footer {
12 | padding-bottom: 24px;
13 | }
14 |
15 | .link-block a {
16 | margin-top: 5px;
17 | margin-bottom: 5px;
18 | }
19 |
20 | .dnerf {
21 | font-variant: small-caps;
22 | }
23 |
24 |
25 | .teaser .hero-body {
26 | padding-top: 0;
27 | padding-bottom: 3rem;
28 | }
29 |
30 | .teaser {
31 | font-family: 'Google Sans', sans-serif;
32 | }
33 |
34 |
35 | .publication-title {
36 | }
37 |
38 | .publication-banner {
39 | max-height: parent;
40 |
41 | }
42 |
43 | .publication-banner video {
44 | position: relative;
45 | left: auto;
46 | top: auto;
47 | transform: none;
48 | object-fit: fit;
49 | }
50 |
51 | .publication-header .hero-body {
52 | }
53 |
54 | .publication-title {
55 | font-family: 'Google Sans', sans-serif;
56 | }
57 |
58 | .publication-authors {
59 | font-family: 'Google Sans', sans-serif;
60 | }
61 |
62 | .publication-venue {
63 | color: #555;
64 | width: fit-content;
65 | font-weight: bold;
66 | }
67 |
68 | .publication-awards {
69 | color: #ff3860;
70 | width: fit-content;
71 | font-weight: bolder;
72 | }
73 |
74 | .publication-authors {
75 | }
76 |
77 | .publication-authors a {
78 | color: hsl(204, 86%, 53%) !important;
79 | }
80 |
81 | .publication-authors a:hover {
82 | text-decoration: underline;
83 | }
84 |
85 | .author-block {
86 | display: inline-block;
87 | }
88 |
89 | .publication-banner img {
90 | }
91 |
92 | .publication-authors {
93 | /*color: #4286f4;*/
94 | }
95 |
96 | .publication-video {
97 | position: relative;
98 | width: 100%;
99 | height: 0;
100 | padding-bottom: 56.25%;
101 |
102 | overflow: hidden;
103 | border-radius: 10px !important;
104 | }
105 |
106 | .publication-video iframe {
107 | position: absolute;
108 | top: 0;
109 | left: 0;
110 | width: 100%;
111 | height: 100%;
112 | }
113 |
114 | .publication-body img {
115 | }
116 |
117 | .results-carousel {
118 | overflow: hidden;
119 | }
120 |
121 | .results-carousel .item {
122 | margin: 5px;
123 | overflow: hidden;
124 | border: 1px solid #bbb;
125 | border-radius: 10px;
126 | padding: 0;
127 | font-size: 0;
128 | }
129 |
130 | .results-carousel video {
131 | margin: 0;
132 | }
133 |
134 |
135 | .interpolation-panel {
136 | background: #f5f5f5;
137 | border-radius: 10px;
138 | }
139 |
140 | .interpolation-panel .interpolation-image {
141 | width: 100%;
142 | border-radius: 5px;
143 | }
144 |
145 | .interpolation-video-column {
146 | }
147 |
148 | .interpolation-panel .slider {
149 | margin: 0 !important;
150 | }
151 |
152 | .interpolation-panel .slider {
153 | margin: 0 !important;
154 | }
155 |
156 | #interpolation-image-wrapper {
157 | width: 100%;
158 | }
159 | #interpolation-image-wrapper img {
160 | border-radius: 5px;
161 | }
162 |
163 | td:nth-child(1), th[colspan]:nth-child(2), th[colspan]:nth-child(3), th:nth-child(4), th:nth-child(7) , th:nth-child(10) {
164 | border-right: 1px solid rgb(200, 200, 200) !important;
165 | }
166 |
167 | th[colspan]{
168 | text-align: center !important; /* Center align text in header cells */
169 | }
170 | th[rowspan] {
171 | vertical-align: middle !important; /* Center align text vertically in header cells with rowspan */
172 | }
173 |
174 | th, td {
175 | text-align: right !important; /* Center align text in header cells */
176 | }
177 |
178 |
--------------------------------------------------------------------------------
/docs/static/js/bulma-slider.min.js:
--------------------------------------------------------------------------------
1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default});
--------------------------------------------------------------------------------
/docs/static/js/index.js:
--------------------------------------------------------------------------------
1 | window.HELP_IMPROVE_VIDEOJS = false;
2 |
3 | $(document).ready(function() {
4 | // Check for click events on the navbar burger icon
5 | $(".navbar-burger").click(function() {
6 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu"
7 | $(".navbar-burger").toggleClass("is-active");
8 | $(".navbar-menu").toggleClass("is-active");
9 |
10 | });
11 |
12 | var options = {
13 | slidesToScroll: 1,
14 | slidesToShow: 3,
15 | loop: true,
16 | infinite: true,
17 | autoplay: false,
18 | autoplaySpeed: 3000,
19 | }
20 |
21 | // Initialize all div with carousel class
22 | var carousels = bulmaCarousel.attach('.carousel', options);
23 |
24 | // Loop on each carousel initialized
25 | for(var i = 0; i < carousels.length; i++) {
26 | // Add listener to event
27 | carousels[i].on('before:show', state => {
28 | console.log(state);
29 | });
30 | }
31 |
32 | // Access to bulmaCarousel instance of an element
33 | var element = document.querySelector('#my-element');
34 | if (element && element.bulmaCarousel) {
35 | // bulmaCarousel instance is available as element.bulmaCarousel
36 | element.bulmaCarousel.on('before-show', function(state) {
37 | console.log(state);
38 | });
39 | }
40 |
41 | /*var player = document.getElementById('interpolation-video');
42 | player.addEventListener('loadedmetadata', function() {
43 | $('#interpolation-slider').on('input', function(event) {
44 | console.log(this.value, player.duration);
45 | player.currentTime = player.duration / 100 * this.value;
46 | })
47 | }, false);*/
48 | preloadInterpolationImages();
49 |
50 | $('#interpolation-slider').on('input', function(event) {
51 | setInterpolationImage(this.value);
52 | });
53 | setInterpolationImage(0);
54 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1);
55 |
56 | bulmaSlider.attach();
57 |
58 | })
59 |
--------------------------------------------------------------------------------
/docs/static/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/basiclab/MAD/f8b0b012eb3886274df41da9dee0eed6e07e3a34/docs/static/poster.pdf
--------------------------------------------------------------------------------
/generate_text_editing.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from typing import List, Optional, Tuple
5 |
6 | from joblib import Parallel, delayed
7 | from loguru import logger
8 | from tqdm import tqdm
9 |
10 | from modeling.text_translation import TextTranslationDiffusion
11 |
12 |
13 | def copy_parameters(from_parameters, to_parameters):
14 | to_parameters = list(to_parameters)
15 | assert len(from_parameters) == len(to_parameters)
16 | for s_param, param in zip(from_parameters, to_parameters):
17 | param.data.copy_(s_param.to(param.device).data)
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--save-folder", default="batch_images", type=str)
23 | parser.add_argument("--source-root", required=True, type=str)
24 | parser.add_argument("--source-list", required=True, type=str)
25 | parser.add_argument("--num-process", default=1, type=int)
26 | parser.add_argument("--num-of-step", default=180, type=int)
27 | parser.add_argument("--img-size", default=512, type=int)
28 | parser.add_argument("--model-path", default=None, type=str)
29 | parser.add_argument("--scheduler", default="ddpm", type=str)
30 | parser.add_argument("--sample-steps", default=1000, type=int)
31 | return parser.parse_args()
32 |
33 |
34 | def generate_image(
35 | img_size: int,
36 | save_folder: str,
37 |     source_list: List[Tuple[str, str, str]],
38 | offset: int,
39 | device: str,
40 | num_of_step: int,
41 | scheduler: str,
42 | sample_steps: int,
43 | model_path: Optional[str] = None,
44 | ):
45 | diffuser = TextTranslationDiffusion(
46 | img_size=img_size,
47 | scheduler=scheduler,
48 | device=device,
49 | model_path=model_path,
50 | sample_steps=sample_steps,
51 | )
52 |     os.makedirs(save_folder, exist_ok=True)
53 |
54 | progress_bar = tqdm(total=len(source_list), position=int(device.split(":")[-1]))
55 | count_error = 0
56 |
57 | for idx, (source_image, source_mask, editing_prompt) in enumerate(source_list):
58 | save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png")
59 | if os.path.exists(save_image_name):
60 | progress_bar.update(1)
61 | continue
62 | if source_mask.endswith("jpg"):
63 | source_mask = source_mask.replace("jpg", "png")
64 | try:
65 | editing_result = diffuser.modify_with_text(
66 | image=source_image,
67 | mask=source_mask,
68 | prompt=[editing_prompt],
69 | start_from_step=num_of_step,
70 | guidance_scale=15,
71 | )
72 | except Exception as e:
73 | logger.error(str(e))
74 | count_error += 1
75 | continue
76 | save_image = editing_result[0]
77 | save_image.save(save_image_name)
78 | progress_bar.update(1)
79 | progress_bar.close()
80 |
81 | if count_error != 0:
82 | print(f"Error in {device}: {count_error}")
83 |
84 |
85 | if __name__ == "__main__":
86 | args = parse_args()
87 |
88 | with open(args.source_list, "r") as f:
89 | data = json.load(f)
90 | source_list = []
91 |     for info in data:
92 | image_path = os.path.join(args.source_root, info["image"])
93 | mask_path = image_path.replace("images", "parsing")
94 | source_list.append((image_path, mask_path, info["style"]))
95 |
96 | task_per_process = len(source_list) // args.num_process
97 | Parallel(n_jobs=args.num_process)(
98 | delayed(generate_image)(
99 | args.img_size,
100 | args.save_folder,
101 | source_list=source_list[
102 | gpu_idx
103 | * task_per_process : (
104 | ((gpu_idx + 1) * task_per_process)
105 | if gpu_idx < args.num_process - 1
106 | else len(source_list)
107 | )
108 | ],
109 | offset=gpu_idx * task_per_process,
110 | device=f"cuda:{gpu_idx}",
111 | num_of_step=args.num_of_step,
112 | model_path=args.model_path,
113 | scheduler=args.scheduler,
114 | sample_steps=args.sample_steps,
115 | )
116 | for gpu_idx in range(args.num_process)
117 | )
118 |
--------------------------------------------------------------------------------
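
The per-GPU work splitting above (and the identical code in `generate_transfer.py` and `generate_translation.py` below) slices the list so that the last process absorbs the remainder; a toy illustration with made-up sizes:

```python
items = list(range(10))                        # stand-in for source_list
num_process = 3
task_per_process = len(items) // num_process   # 3

chunks = [
    items[i * task_per_process : (i + 1) * task_per_process if i < num_process - 1 else len(items)]
    for i in range(num_process)
]
print(chunks)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```
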
/generate_transfer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import List, Tuple
4 |
5 | import torch
6 | from joblib import Parallel, delayed
7 | from loguru import logger
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 | from config import create_cfg, merge_possible_with_base, show_config
12 | from modeling import build_model
13 | from modeling.translation import TranslationDiffusion
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--config", default=None, type=str)
19 | parser.add_argument("--save-folder", default="batch_images", type=str)
20 | parser.add_argument("--source-root", required=True, type=str)
21 | parser.add_argument("--target-root", required=True, type=str)
22 | parser.add_argument("--source-list", required=True, type=str)
23 | parser.add_argument("--target-list", required=True, type=str)
24 | parser.add_argument("--source-label", required=True, type=int)
25 | parser.add_argument("--target-label", required=True, type=int)
26 | parser.add_argument("--num-process", default=1, type=int)
27 | parser.add_argument("--inpainting", action="store_true", default=False)
28 | parser.add_argument("--cam", action="store_true", default=False)
29 | parser.add_argument("--opts", nargs=argparse.REMAINDER, default=None, type=str)
30 | return parser.parse_args()
31 |
32 |
33 | def generate_image(
34 | cfg,
35 | save_folder: str,
36 | source_list: List[Tuple[str, str]],
37 | target_list: List[Tuple[str, str]],
38 | source_label: int,
39 | target_label: int,
40 | offset: int,
41 | device: str,
42 | inpainting: bool,
43 | use_cam: bool,
44 | ):
45 | torch.cuda.set_device(device)
46 | model = build_model(cfg).to(device)
47 | model.eval()
48 |
49 | diffuser = TranslationDiffusion(cfg, device)
50 |     os.makedirs(save_folder, exist_ok=True)
51 |
52 | count_error = 0
53 | with tqdm(total=len(source_list), position=int(device.split(":")[-1])) as progress_bar:
54 | for idx, ((source_image, source_mask), (target_image, target_mask)) in enumerate(
55 | zip(source_list, target_list)
56 | ):
57 | save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png")
58 | if os.path.exists(save_image_name):
59 | progress_bar.update(1)
60 | continue
61 | if source_mask.endswith("jpg"):
62 | source_mask = source_mask.replace("jpg", "png")
63 | if target_mask.endswith("jpg"):
64 | target_mask = target_mask.replace("jpg", "png")
65 |
66 | try:
67 | transfer_result = diffuser.image_translation(
68 | source_model=model,
69 | target_model=model,
70 | source_image=source_image,
71 | target_image=target_image,
72 | source_class_label=source_label,
73 | target_class_label=target_label,
74 | source_parsing_mask=source_mask,
75 | target_parsing_mask=target_mask,
76 | use_morphing=True,
77 | use_encode_eps=True,
78 | use_cam=use_cam,
79 | inpainting=inpainting,
80 | )
81 | except Exception as e:
82 | logger.error(f"Error in {device}: {e}")
83 | count_error += 1
84 | continue
85 | save_image = Image.fromarray(
86 | (transfer_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
87 | )
88 | save_image.save(save_image_name)
89 | progress_bar.update(1)
90 | if count_error != 0:
91 | print(f"Error in {device}: {count_error}")
92 |
93 |
94 | if __name__ == "__main__":
95 | args = parse_args()
96 | cfg = create_cfg()
97 | if args.config:
98 | merge_possible_with_base(cfg, args.config)
99 | if args.opts:
100 | cfg.merge_from_list(args.opts)
101 | show_config(cfg)
102 |
103 | source_list = []
104 | target_list = []
105 |
106 | with open(args.source_list, "r") as f_source, open(args.target_list, "r") as f_target:
107 | source_lines = f_source.readlines()
108 | target_lines = f_target.readlines()
109 | for idx in range(len(source_lines)):
110 | if source_lines[idx].strip() == "" or target_lines[idx].strip() == "":
111 | continue
112 | source_path = os.path.join(args.source_root, source_lines[idx].strip())
113 | source_list.append((source_path, source_path.replace("images", "parsing")))
114 | target_path = os.path.join(args.target_root, target_lines[idx].strip())
115 | target_list.append((target_path, target_path.replace("images", "parsing")))
116 |
117 | assert len(source_list) == len(target_list), "Source and target list should have same length"
118 | task_per_process = len(source_list) // args.num_process
119 | Parallel(n_jobs=args.num_process)(
120 | delayed(generate_image)(
121 | cfg,
122 | args.save_folder,
123 | source_list=source_list[
124 | gpu_idx
125 | * task_per_process : (
126 | ((gpu_idx + 1) * task_per_process)
127 | if gpu_idx < args.num_process - 1
128 | else len(source_list)
129 | )
130 | ],
131 | target_list=target_list[
132 | gpu_idx
133 | * task_per_process : (
134 | ((gpu_idx + 1) * task_per_process)
135 | if gpu_idx < args.num_process - 1
136 | else len(target_list)
137 | )
138 | ],
139 | source_label=args.source_label,
140 | target_label=args.target_label,
141 | offset=gpu_idx * task_per_process,
142 | device=f"cuda:{gpu_idx}",
143 | inpainting=args.inpainting,
144 | use_cam=args.cam,
145 | )
146 | for gpu_idx in range(args.num_process)
147 | )
148 |
--------------------------------------------------------------------------------
/generate_translation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import List, Tuple
4 |
5 | import torch
6 | from joblib import Parallel, delayed
7 | from loguru import logger
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 | from config import create_cfg, merge_possible_with_base, show_config
12 | from modeling import build_model
13 | from modeling.translation import TranslationDiffusion
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--config", default=None, type=str)
19 | parser.add_argument("--save-folder", default="batch_images", type=str)
20 | parser.add_argument("--source-root", required=True, type=str)
21 | parser.add_argument("--source-list", required=True, type=str)
22 | parser.add_argument("--source-label", required=True, type=int)
23 | parser.add_argument("--target-label", required=True, type=int)
24 | parser.add_argument("--num-process", default=1, type=int)
25 | parser.add_argument("--num-of-step", default=700, type=int)
26 | parser.add_argument("--opts", nargs=argparse.REMAINDER, default=None, type=str)
27 | return parser.parse_args()
28 |
29 |
30 | def generate_image(
31 | cfg,
32 | save_folder: str,
33 | source_list: List[Tuple[str, str]],
34 | source_label: int,
35 | target_label: int,
36 | offset: int,
37 | device: str,
38 | num_of_step: int,
39 | ):
40 | torch.cuda.set_device(device)
41 | model = build_model(cfg).to(device)
42 | model.eval()
43 |
44 | diffuser = TranslationDiffusion(cfg, device)
45 |     os.makedirs(save_folder, exist_ok=True)
46 |
47 | count_error = 0
48 | with tqdm(
49 | total=len(source_list), position=int(device.split(":")[-1])
50 | ) as progress_bar:
51 | for idx, (source_image, source_mask) in enumerate(source_list):
52 | save_image_name = os.path.join(save_folder, f"pred_{idx + offset}.png")
53 | if os.path.exists(save_image_name):
54 | progress_bar.update(1)
55 | continue
56 | if source_mask.endswith("jpg"):
57 | source_mask = source_mask.replace("jpg", "png")
58 | try:
59 | transfer_result = diffuser.domain_translation(
60 | source_model=model,
61 | target_model=model,
62 | source_image=source_image,
63 | source_class_label=source_label,
64 | target_class_label=target_label,
65 | parsing_mask=source_mask,
66 | start_from_step=num_of_step,
67 | )
68 | except Exception as e:
69 | logger.error(str(e))
70 | count_error += 1
71 | continue
72 | save_image = Image.fromarray(
73 | (transfer_result[0].permute(1, 2, 0).cpu().numpy() * 255).astype(
74 | "uint8"
75 | )
76 | )
77 |             save_image.save(save_image_name)
78 |             progress_bar.update(1)
79 | if count_error != 0:
80 | print(f"Error in {device}: {count_error}")
81 |
82 |
83 | if __name__ == "__main__":
84 | args = parse_args()
85 | cfg = create_cfg()
86 | if args.config:
87 | merge_possible_with_base(cfg, args.config)
88 | if args.opts:
89 | cfg.merge_from_list(args.opts)
90 | show_config(cfg)
91 |
92 | source_list = []
93 | with open(args.source_list, "r") as f:
94 | for line in f.readlines():
95 | source_line = line.strip()
96 | image_path = os.path.join(args.source_root, source_line)
97 | mask_path = image_path.replace("images", "parsing")
98 | source_list.append((image_path, mask_path))
99 |
100 | task_per_process = len(source_list) // args.num_process
101 | Parallel(n_jobs=args.num_process)(
102 | delayed(generate_image)(
103 | cfg,
104 | args.save_folder,
105 | source_list=source_list[
106 | gpu_idx * task_per_process : (
107 | ((gpu_idx + 1) * task_per_process)
108 | if gpu_idx < args.num_process - 1
109 | else len(source_list)
110 | )
111 | ],
112 | source_label=args.source_label,
113 | target_label=args.target_label,
114 | offset=gpu_idx * task_per_process,
115 | device=f"cuda:{gpu_idx}",
116 | num_of_step=args.num_of_step,
117 | )
118 | for gpu_idx in range(args.num_process)
119 | )
120 |
121 |
--------------------------------------------------------------------------------
/misc/__init__.py:
--------------------------------------------------------------------------------
1 | from .meter import AverageMeter, MetricMeter
2 |
3 | __all__ = ["AverageMeter", "MetricMeter"]
4 |
--------------------------------------------------------------------------------
/misc/compute_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import tabulate
6 | import torch
7 | import torch.nn as nn
8 | import torch.utils.data
9 | from PIL import Image, ImageFile
10 | from torchvision import transforms
11 | from tqdm import tqdm
12 | from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
13 | from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 |
16 | ROOT = {
17 | "mt": "data/mtdataset/images",
18 | "mt_removal": "generate_outputs/mt_removal",
19 | "beauty": "data/beautyface/images",
20 | "beauty_removal": "generate_outputs/beauty_removal",
21 | "wild": "data/wild/images",
22 | "wild_removal": "generate_outputs/wild_removal",
23 | }
24 |
25 |
26 | class InceptionModel(torch.nn.Module):
27 | def __init__(self, num_class=1000):
28 | super(InceptionModel, self).__init__()
29 | self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
30 | self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
31 | self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
32 | self.MaxPool_1 = nn.MaxPool2d(kernel_size=3, stride=2)
33 | self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
34 | self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
35 | self.MaxPool_2 = nn.MaxPool2d(kernel_size=3, stride=2)
36 | self.Mixed_5b = InceptionA(192, pool_features=32)
37 | self.Mixed_5c = InceptionA(256, pool_features=64)
38 | self.Mixed_5d = InceptionA(288, pool_features=64)
39 | self.Mixed_6a = InceptionB(288)
40 | self.Mixed_6b = InceptionC(768, channels_7x7=128)
41 | self.Mixed_6c = InceptionC(768, channels_7x7=160)
42 | self.Mixed_6d = InceptionC(768, channels_7x7=160)
43 | self.Mixed_6e = InceptionC(768, channels_7x7=192)
44 | self.Mixed_7a = InceptionD(768)
45 | self.Mixed_7b = InceptionE_1(1280)
46 | self.Mixed_7c = InceptionE_2(2048)
47 | self.AvgPool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
48 | self.fc = nn.Linear(2048, num_class)
49 |
50 | def forward(self, x):
51 | x = interpolate_bilinear_2d_like_tensorflow1x(
52 | x,
53 | size=(299, 299),
54 | align_corners=False,
55 | )
56 | x = self.Conv2d_1a_3x3(x)
57 | x = self.Conv2d_2a_3x3(x)
58 | x = self.Conv2d_2b_3x3(x)
59 | x = self.MaxPool_1(x)
60 | x = self.Conv2d_3b_1x1(x)
61 | x = self.Conv2d_4a_3x3(x)
62 | x = self.MaxPool_2(x)
63 | x = self.Mixed_5b(x)
64 | x = self.Mixed_5c(x)
65 | x = self.Mixed_5d(x)
66 | x = self.Mixed_6a(x)
67 | x = self.Mixed_6b(x)
68 | x = self.Mixed_6c(x)
69 | x = self.Mixed_6d(x)
70 | x = self.Mixed_6e(x)
71 | x = self.Mixed_7a(x)
72 | x = self.Mixed_7b(x)
73 | x = self.Mixed_7c(x)
74 | x = self.AvgPool(x)
75 | x = torch.flatten(x, 1)
76 | x = self.fc(x)
77 |
78 | return x
79 |
80 |
81 | def get_label(image_root):
82 | all_images = []
83 | for root in image_root:
84 | all_images.extend(
85 | list(glob.glob(os.path.join(root, "**/*.png"), recursive=True))
86 | + list(glob.glob(os.path.join(root, "**/*.jpg"), recursive=True))
87 | )
88 | label = {name: idx for idx, name in enumerate(sorted(set(all_images)))}
89 | return label
90 |
91 |
92 | def main(args):
93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94 | label_dict = get_label(["data/mtdataset/images", "data/beautyface/images", "data/wild/images"])
95 | print(len(label_dict))
96 | model = InceptionModel(num_class=len(label_dict)).to(device)
97 | model.load_state_dict(torch.load("inception.pth", map_location="cpu"))
98 | model = model.to(device)
99 | model.eval()
100 |
101 | transform_list = transforms.Compose(
102 | [
103 | transforms.Resize((224, 224), antialias=True),
104 | transforms.CenterCrop((224, 224)),
105 | transforms.ToTensor(),
106 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
107 | ]
108 | )
109 |
110 | table = []
111 | target_non_makeup_img = []
112 | with open(args.non_makeup_file, "r") as f:
113 | for line in f:
114 | if line.strip() != "":
115 | target_non_makeup_img.append(os.path.join(ROOT["mt"], line.strip()))
116 | for target in tqdm(args.path_list):
117 | if not os.path.exists(target):
118 | continue
119 | target_file_sorted = sorted(
120 | list(glob.glob(f"{target}/*.png")),
121 | key=lambda x: int(os.path.basename(x).split(".")[0].split("_")[-1]),
122 | )
123 | selected_non_makeup_image = []
124 | for file in target_file_sorted:
125 | num = int(os.path.basename(file).split(".")[0].split("_")[-1])
126 | selected_non_makeup_image.append(target_non_makeup_img[num])
127 |
128 | score = 0
129 | for img_name, label_name in zip(target_file_sorted, selected_non_makeup_image):
130 | target_label = label_dict[label_name]
131 |
132 | img = Image.open(img_name).convert("RGB")
133 | img = transform_list(img).unsqueeze(0).to(device)
134 |
135 | with torch.no_grad():
136 | output = model(img)
137 | output = nn.functional.softmax(output, dim=1)[0].cpu().tolist()
138 | score += output[target_label]
139 | score /= len(selected_non_makeup_image)
140 | table.append([target.split("/")[1], f"{score:.3f}"])
141 | print(tabulate.tabulate(table, headers=["Approach", "Acc"], tablefmt="grid"))
142 |
143 |
144 | if __name__ == "__main__":
145 |     parser = argparse.ArgumentParser(description="Compute classification accuracy")
146 | parser.add_argument(
147 | "--non-makeup-file",
148 |         help="File denoting the order of non-makeup images",
149 | default="data/nomakeup_test_mt.txt",
150 | )
151 | parser.add_argument(
152 | "--path-list",
153 | help="List of path for evaluation",
154 | nargs="+",
155 | type=str,
156 | required=True,
157 | )
158 | parser.add_argument(
159 | "--type",
160 | help="Type of dataset",
161 | choices=["mt", "beauty", "wild"],
162 | default="mt",
163 | )
164 | args = parser.parse_args()
165 | main(args)
166 |
--------------------------------------------------------------------------------
/misc/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import sys
4 |
5 | from tqdm import tqdm
6 |
7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
8 |
9 | import argparse
10 | import glob
11 |
12 | import numpy as np
13 | import tabulate
14 | import torch
15 | import torch.nn
16 | import torchvision.transforms.functional as F
17 | from PIL import Image, ImageFile
18 | from torchvision import transforms
19 |
20 | import torch_fidelity
21 |
22 | ImageFile.LOAD_TRUNCATED_IMAGES = True
23 |
24 | ROOT = {
25 | "mt": "data/mtdataset/images",
26 | "mt_removal": "generate_outputs/mt_removal",
27 | "beauty": "data/beautyface/images",
28 | "beauty_removal": "generate_outputs/beauty_removal",
29 | "wild": "data/wild/images",
30 | "wild_removal": "generate_outputs/wild_removal",
31 | }
32 |
33 |
34 | def set_seed(seed: int = 385832):
35 | random.seed(seed)
36 | np.random.seed(seed)
37 | torch.manual_seed(seed)
38 | torch.cuda.manual_seed(seed)
39 | torch.cuda.manual_seed_all(seed)
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 |
43 |
44 | class TransformPILtoRGBTensor:
45 | def __call__(self, img):
46 | return F.pil_to_tensor(img)
47 |
48 |
49 | class ImagesPathDataset(torch.utils.data.Dataset):
50 | def __init__(self, files, img_size=224):
51 | self.files = files
52 | self.transforms = transforms.Compose(
53 | [
54 | transforms.Resize((img_size, img_size), antialias=True),
55 | transforms.CenterCrop((img_size, img_size)),
56 | TransformPILtoRGBTensor(),
57 | ]
58 | )
59 |
60 | def __len__(self):
61 | return len(self.files)
62 |
63 | def __getitem__(self, i):
64 | path = self.files[i]
65 | img = Image.open(path).convert("RGB")
66 |
67 | img = self.transforms(img)
68 | return img
69 |
70 |
71 | class CheckImagesPathDataset(torch.utils.data.Dataset):
72 | def __init__(
73 | self,
74 | root,
75 | ori_file=None,
76 | target_files=None,
77 | order_files=None,
78 | img_size=224,
79 | ):
80 | self.files = []
81 | if ori_file is None:
82 | for file in order_files:
83 | num = int(os.path.basename(file).split(".")[0].split("_")[-1])
84 | self.files.append(os.path.join(root, f"pred_{num}.png"))
85 | else:
86 | with open(ori_file, "r") as f:
87 | line_of_file = f.readlines()
88 | for file in target_files:
89 | num = int(os.path.basename(file).split(".")[0].split("_")[-1])
90 | self.files.append(os.path.join(root, line_of_file[num].strip()))
91 | self.transforms = transforms.Compose(
92 | [
93 | transforms.Resize((img_size, img_size), antialias=True),
94 | transforms.CenterCrop((img_size, img_size)),
95 | TransformPILtoRGBTensor(),
96 | ]
97 | )
98 |
99 | def __len__(self):
100 | return len(self.files)
101 |
102 | def __getitem__(self, i):
103 | path = self.files[i]
104 | img = Image.open(path).convert("RGB")
105 | img = self.transforms(img)
106 | return img
107 |
108 |
109 | def main(args):
110 | table = []
111 | target_non_makeup_img = []
112 | with open(args.non_makeup_file, "r") as f:
113 | for line in f:
114 | target_non_makeup_img.append(os.path.join(ROOT["mt"], line.strip()))
115 | for target in tqdm(args.path_list):
116 | if not os.path.exists(target):
117 | continue
118 | target_file_sorted = sorted(
119 | list(glob.glob(f"{target}/*.png")),
120 | key=lambda x: int(os.path.basename(x).split(".")[0].split("_")[-1]),
121 | )
122 | selected_non_makeup_image = []
123 | for file in target_file_sorted:
124 | num = int(os.path.basename(file).split(".")[0].split("_")[-1])
125 | selected_non_makeup_image.append(target_non_makeup_img[num])
126 | precision = torch_fidelity.calculate_metrics(
127 | input1=ImagesPathDataset(files=target_file_sorted),
128 | input2=ImagesPathDataset(files=selected_non_makeup_image),
129 | input3=CheckImagesPathDataset(
130 | root=ROOT[args.type],
131 | ori_file=args.makeup_file,
132 | target_files=target_file_sorted,
133 | ),
134 | prc=True,
135 | device="cuda",
136 | verbose=False,
137 | cache=False,
138 | feature_extractor="vgg16",
139 | feature_extractor_weights_path="vgg.pth",
140 | )["precision"]
141 |
142 | recall = torch_fidelity.calculate_metrics(
143 | input1=ImagesPathDataset(files=target_file_sorted),
144 | input2=CheckImagesPathDataset(
145 | root=ROOT[args.type],
146 | ori_file=args.makeup_file,
147 | target_files=target_file_sorted,
148 | ),
149 | input3=CheckImagesPathDataset( # Use mt to remove original non-makeup feature
150 | root=ROOT["mt"],
151 | ori_file=args.non_makeup_file,
152 | target_files=target_file_sorted,
153 | ),
154 | input4=CheckImagesPathDataset(
155 | root=ROOT[
156 | f"{args.type}_removal"
157 | ], # Use makeup to non-makeup image to remove non-makeup feature
158 | order_files=target_file_sorted,
159 | ),
160 | prc=True,
161 | device="cuda",
162 | cache=False,
163 | verbose=False,
164 | feature_extractor="vgg16",
165 | feature_extractor_weights_path="vgg.pth",
166 | )["recall"]
167 |
168 | kid = torch_fidelity.calculate_metrics(
169 | input1=ImagesPathDataset(files=target_file_sorted),
170 | input2=CheckImagesPathDataset(
171 | root=ROOT[args.type],
172 | ori_file=args.makeup_file,
173 | target_files=target_file_sorted,
174 | ),
175 | input3=CheckImagesPathDataset( # Use mt to remove original non-makeup feature
176 | root=ROOT["mt"],
177 | ori_file=args.non_makeup_file,
178 | target_files=target_file_sorted,
179 | ),
180 | input4=CheckImagesPathDataset(
181 | root=ROOT[
182 | f"{args.type}_removal"
183 | ], # Use makeup to non-makeup image to remove non-makeup feature
184 | order_files=target_file_sorted,
185 | ),
186 | kid=True,
187 | device="cuda",
188 | cache=False,
189 | verbose=False,
190 | feature_extractor="inception-v3-compat",
191 | feature_extractor_weights_path="inception.pth",
192 | kid_subset_size=min(1000, len(target_file_sorted)),
193 | )["kernel_inception_distance_mean"]
194 |
195 | table.append(
196 | [
197 | target.split("/")[1],
198 | f"{precision:.3f}",
199 | f"{recall:.3f}",
200 | f"{kid:.3f}",
201 | ]
202 | )
203 | print(
204 | tabulate.tabulate(
205 | table, headers=["Approach", "Precision", "Recall", "KID"], tablefmt="fancy_grid"
206 | )
207 | )
208 |
209 |
210 | if __name__ == "__main__":
211 | parser = argparse.ArgumentParser(description="Compute KID")
212 | parser.add_argument(
213 | "--non-makeup-file",
214 | help="File denoting the order of makeup image",
215 | default="data/nomakeup_test_mt.txt",
216 | )
217 | parser.add_argument(
218 | "--makeup-file",
219 | help="File denoting the order of makeup image",
220 | default="data/makeup_test_mt.txt",
221 | )
222 | parser.add_argument(
223 | "--path-list",
224 | help="List of path for evaluation",
225 | nargs="+",
226 | type=str,
227 | required=True,
228 | )
229 | parser.add_argument(
230 | "--type",
231 | help="Type of dataset",
232 | choices=["mt", "beauty", "wild"],
233 | default="mt",
234 | )
235 | args = parser.parse_args()
236 | set_seed()
237 | main(args)
238 |
--------------------------------------------------------------------------------
/misc/compute_removal_metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import cv2
6 | import skimage.metrics
7 | import tabulate
8 | from tqdm import tqdm
9 |
10 | ROOT = {
11 | "mt": "data/mtdataset/images",
12 | }
13 |
14 |
15 | def main(args):
16 | table = []
17 | for target in tqdm(args.path_list):
18 | if not os.path.exists(target):
19 | raise FileNotFoundError(f"{target} not found")
20 | target_file_sorted = sorted(
21 | list(glob.glob(f"{target}/pred_*.png")),
22 | key=lambda x: int(os.path.basename(x).split(".")[0].split("_")[-1]),
23 | )
24 | target_non_makeup_img = []
25 | with open(args.origin_file, "r") as f:
26 | for line in f:
27 | if line.strip() != "":
28 | target_non_makeup_img.append(os.path.join(ROOT["mt"], line.strip()))
29 |
30 | total_ssim = 0
31 | total_psnr = 0
32 | for file in target_file_sorted:
33 | idx = int(os.path.basename(file).split(".")[0].rsplit("_", 1)[1])
34 | pred = cv2.imread(file)
35 | non_makeup = cv2.imread(target_non_makeup_img[idx])
36 | pred = cv2.resize(pred, (non_makeup.shape[1], non_makeup.shape[0]))
37 | ssim = skimage.metrics.structural_similarity(
38 | cv2.cvtColor(pred, cv2.COLOR_BGR2GRAY),
39 | cv2.cvtColor(non_makeup, cv2.COLOR_BGR2GRAY),
40 | )
41 | psnr = skimage.metrics.peak_signal_noise_ratio(
42 | cv2.cvtColor(pred, cv2.COLOR_BGR2GRAY),
43 | cv2.cvtColor(non_makeup, cv2.COLOR_BGR2GRAY),
44 | )
45 | total_ssim += ssim
46 | total_psnr += psnr
47 | table.append(
48 | [
49 | target,
50 | total_ssim / len(target_file_sorted),
51 | total_psnr / len(target_file_sorted),
52 | ]
53 | )
54 | print(tabulate.tabulate(table, headers=["Appraoch", "SSIM", "PSNR"], tablefmt="fancy_grid"))
55 |
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser(description="Compute KID")
59 | parser.add_argument(
60 | "--origin-file",
61 | help="File denoting the order of makeup image",
62 | default="data/nomakeup_test_mt.txt",
63 | )
64 | parser.add_argument(
65 | "--path-list",
66 | help="List of path for evaluation",
67 | nargs="+",
68 | type=str,
69 | required=True,
70 | )
71 | args = parser.parse_args()
72 | main(args)
73 |
--------------------------------------------------------------------------------
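
A minimal, self-contained sketch of the SSIM/PSNR computation used above, run on synthetic grayscale arrays instead of real de-makeup predictions (the array sizes and the noise are made up purely for illustration):

import numpy as np
import skimage.metrics

# Synthetic 8-bit grayscale images standing in for a predicted de-makeup
# result and its non-makeup reference.
rng = np.random.default_rng(0)
ref = rng.integers(0, 256, size=(64, 64), dtype=np.uint8)
pred = np.clip(ref.astype(np.int16) + rng.integers(-10, 11, size=(64, 64)), 0, 255).astype(np.uint8)

ssim = skimage.metrics.structural_similarity(pred, ref)
psnr = skimage.metrics.peak_signal_noise_ratio(pred, ref)
print(f"SSIM={ssim:.3f}  PSNR={psnr:.2f} dB")
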
/misc/compute_text_metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import clip
6 | import tabulate
7 | import torch
8 | from PIL import Image
9 | from tqdm import tqdm
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--test-file", type=str, required=True)
15 | parser.add_argument("--model-path", type=str, required=True)
16 | parser.add_argument("--orig-root", type=str, required=True)
17 | parser.add_argument("--test-dir", type=str, nargs="+", required=True)
18 | return parser.parse_args()
19 |
20 |
21 | def compute_clip_text(model, preprocess, image_dir, test_meta):
22 | all_images = []
23 | all_texts = []
24 | for idx, val in enumerate(tqdm(list(test_meta.values()))):
25 | all_images.append(preprocess(Image.open(os.path.join(image_dir, f"pred_{idx}.png"))))
26 | all_texts.append(f"makeup with {', '.join(val['style'])}")
27 | image = torch.stack(all_images).to(device)
28 | text = clip.tokenize(all_texts).to(device)
29 |
30 | with torch.no_grad():
31 | image_features = model.encode_image(image)
32 | text_features = model.encode_text(text)
33 |
34 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
35 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
36 | similarity = torch.nn.functional.cosine_similarity(image_features, text_features).mean()
37 | return similarity
38 |
39 |
40 | def compute_image_similarity(model, preprocess, image_dir, orig_dir, test_meta):
41 | all_ori_images = []
42 | all_result_images = []
43 | for idx, val in enumerate(tqdm(list(test_meta.values()))):
44 | all_ori_images.append(preprocess(Image.open(os.path.join(orig_dir, val["nonmakeup"]))))
45 | all_result_images.append(preprocess(Image.open(os.path.join(image_dir, f"pred_{idx}.png"))))
46 |
47 | ori_image = torch.stack(all_ori_images).to(device)
48 | result_image = torch.stack(all_result_images).to(device)
49 |
50 | with torch.no_grad():
51 | ori_image_features = model.encode_image(ori_image)
52 | result_image_features = model.encode_image(result_image)
53 |
54 | ori_image_features = ori_image_features / ori_image_features.norm(dim=1, keepdim=True)
55 | result_image_features = result_image_features / result_image_features.norm(
56 | dim=1, keepdim=True
57 | )
58 | similarity = torch.nn.functional.cosine_similarity(
59 | ori_image_features, result_image_features
60 | ).mean()
61 | return similarity
62 |
63 |
64 | def compute_style_similarity(model, preprocess, image_dir, orig_dir, test_meta):
65 | all_ori_images = []
66 | all_result_images = []
67 | for idx, key in enumerate(tqdm(list(test_meta.keys()))):
68 | all_ori_images.append(preprocess(Image.open(os.path.join(orig_dir, key))))
69 | all_result_images.append(preprocess(Image.open(os.path.join(image_dir, f"pred_{idx}.png"))))
70 |
71 | ori_image = torch.stack(all_ori_images).to(device)
72 | result_image = torch.stack(all_result_images).to(device)
73 |
74 | with torch.no_grad():
75 | ori_image_features = model.encode_image(ori_image)
76 | result_image_features = model.encode_image(result_image)
77 |
78 | ori_image_features = ori_image_features / ori_image_features.norm(dim=1, keepdim=True)
79 | result_image_features = result_image_features / result_image_features.norm(
80 | dim=1, keepdim=True
81 | )
82 | similarity = torch.nn.functional.cosine_similarity(
83 | ori_image_features, result_image_features
84 | ).mean()
85 | return similarity
86 |
87 |
88 | if __name__ == "__main__":
89 | device = "cuda" if torch.cuda.is_available() else "cpu"
90 |
91 | args = parse_args()
92 |
93 | model, preprocess = clip.load(args.model_path, device=device)
94 |
95 | with open(args.test_file, "r") as f:
96 | test_meta = json.load(f)
97 |
98 | text_match_score = []
99 | image_match_score = []
100 | style_score = []
101 | for test_dir in args.test_dir:
102 | text_match_score.append(compute_clip_text(model, preprocess, test_dir, test_meta))
103 | image_match_score.append(
104 | compute_image_similarity(model, preprocess, test_dir, args.orig_root, test_meta)
105 | )
106 | style_score.append(
107 | compute_style_similarity(model, preprocess, test_dir, args.orig_root, test_meta)
108 | )
109 |
110 | print(
111 | tabulate.tabulate(
112 | [
113 | ["CLIP text"] + text_match_score,
114 | ["CLIP image"] + image_match_score,
115 | ["CLIp style"] + style_score,
116 | ],
117 | headers=["Metric"] + [os.path.basename(dir_name) for dir_name in args.test_dir],
118 | tablefmt="pretty",
119 | )
120 | )
121 |
--------------------------------------------------------------------------------
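
All three CLIP scores above reduce to the same operation: L2-normalize the feature rows, then average the row-wise cosine similarity. A tiny self-contained sketch with random stand-in features (no CLIP model required):

import torch

# Random stand-ins for CLIP image and text embeddings (batch of 4, dim 512).
image_features = torch.randn(4, 512)
text_features = torch.randn(4, 512)

# Normalize each row to unit length, then take the mean row-wise cosine similarity.
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
score = torch.nn.functional.cosine_similarity(image_features, text_features).mean()
print(f"mean CLIP-style similarity: {score.item():.4f}")
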
/misc/constant.py:
--------------------------------------------------------------------------------
1 | # fmt: off
2 | LEFT_EYE_CONTOUR = [448, 449, 450, 451, 452, 453, 464, 413, 285, 295, 282, 283, 276, 383, 265, 261]
3 | RIGHT_EYE_CONTOUR = [231, 230, 229, 228, 111, 143, 156, 46, 53, 52, 65, 55, 193, 244, 233, 232]
4 | NUM_TO_EYE_CONTOUR = {
5 | 1: RIGHT_EYE_CONTOUR,
6 | 6: LEFT_EYE_CONTOUR,
7 | }
8 |
9 | # Idx for mediapipe facial landmarks
10 | MEDIAPIPE_LIPS_IDX = [0, 267, 269, 270, 13, 14, 17, 402, 146, 405, 409, 415, 291, 37, 39, 40, 178, 308, 181, 310, 311, 312, 185, 314, 317, 318, 61, 191, 321, 324, 78, 80, 81, 82, 84, 87, 88, 91, 95, 375]
11 | MEDIAPIPE_LEFT_EYES_IDX = [384, 385, 386, 387, 388, 390, 263, 362, 398, 466, 373, 374, 249, 380, 381, 382]
12 | MEDIAPIPE_RIGHT_EYES_IDX = [160, 33, 161, 163, 133, 7, 173, 144, 145, 246, 153, 154, 155, 157, 158, 159]
13 | MEDIAPIPE_LEFT_EYEBROW_IDX = [65, 66, 70, 105, 107, 46, 52, 53, 55, 63]
14 | MEDIAPIPE_RIGHT_EYEBROW_IDX = [293, 295, 296, 300, 334, 336, 276, 282, 283, 285]
15 | MEDIAPIPE_NOSE_IDX = [1, 2, 4, 5, 6, 19, 275, 278, 294, 168, 45, 48, 440, 64, 195, 197, 326, 327, 344, 220, 94, 97, 98, 115]
16 | MEDIAPIPE_OVAL = [132, 389, 136, 10, 397, 400, 148, 149, 150, 21, 152, 284, 288, 162, 297, 172, 176, 54, 58, 323, 67, 454, 332, 338, 93, 356, 103, 361, 234, 109, 365, 379, 377, 378, 251, 127]
17 | MEDIAPIPE_LANDMARKS = MEDIAPIPE_LIPS_IDX + MEDIAPIPE_LEFT_EYES_IDX + MEDIAPIPE_RIGHT_EYES_IDX + MEDIAPIPE_LEFT_EYEBROW_IDX + MEDIAPIPE_RIGHT_EYEBROW_IDX + MEDIAPIPE_NOSE_IDX
18 | MEDIAPIPE_REMOVE_LANDMARKS = [135, 136, 169, 150, 170, 149, 140, 176, 171, 148, 175, 152, 396, 377, 369, 400, 395, 378, 394, 379, 364, 365, 367, 397, 435, 288, 361, 401, 323, 366, 454, 447, 356, 264, 368, 389, 251, 284, 54, 21, 162, 139, 127, 34, 234, 227, 93, 137, 132, 177, 215, 58, 138, 172]
19 |
20 | # Idx for dlib facial landmarks
21 | DLIB_LANDMARKS = list(range(0, 17))
22 |
--------------------------------------------------------------------------------
/misc/convert_beauty_face.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import glob
3 | import os
4 |
5 | import click
6 | import numpy as np
7 | from PIL import Image
8 | from tqdm import tqdm
9 |
10 | CONVERT_DICT = {
11 | 12: 9, # up lip
12 | 1: 4, # face
13 | 10: 8, # nose
14 | 14: 10, # neck
15 | 17: 12, # hair
16 | 4: 1, # right eye
17 | 5: 6, # left eye
18 | 2: 2, # right eyebrow
19 | 3: 7, # left eyebrow
20 | 7: 5, # right ear
21 | 8: 3, # left ear
22 | 9: 0, # ear ring
23 | 11: 11, # teeth
24 | 16: 0, # shirt
25 | }
26 |
27 |
28 | @click.command()
29 | @click.option("--original", help="Original json file", type=click.Path(exists=True), required=True)
30 | @click.option("--save_path", help="Original json file", required=True)
31 | def main(original, save_path):
32 | os.makedirs(save_path, exist_ok=True)
33 | original_mask = glob.glob(os.path.join(original, "*.png"))
34 |
35 | for mask_name in tqdm(original_mask):
36 | mask = np.array(Image.open(mask_name))
37 | new_mask = copy.deepcopy(mask)
38 | for key, value in CONVERT_DICT.items():
39 | new_mask[mask == key] = value
40 | new_mask = Image.fromarray(new_mask)
41 | new_mask.save(os.path.join(save_path, os.path.basename(mask_name)))
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
/misc/meter.py:
--------------------------------------------------------------------------------
1 | """Taken from
2 | https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/utils/avgmeter.py
3 | """
4 |
5 | from collections import defaultdict
6 |
7 | import torch
8 |
9 |
10 | class AverageMeter(object):
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 |
27 | class MetricMeter(object):
28 | def __init__(self, delimiter="\t"):
29 | self.meters = defaultdict(AverageMeter)
30 | self.delimiter = delimiter
31 |
32 | def update(self, input_dict):
33 | if input_dict is None:
34 | return
35 |
36 | if not isinstance(input_dict, dict):
37 | raise TypeError("Input to MetricMeter.update() must be a dictionary")
38 |
39 | for k, v in input_dict.items():
40 | if isinstance(v, torch.Tensor):
41 | v = v.item()
42 | self.meters[k].update(v)
43 |
44 | def __str__(self):
45 | output_str = []
46 | for name, meter in self.meters.items():
47 | output_str.append(f"{name} {meter.val:.4f} ({meter.avg:.4f})")
48 | return self.delimiter.join(output_str)
49 |
50 | def get_log_dict(self):
51 | log_dict = {}
52 | for name, meter in self.meters.items():
53 | log_dict[name] = meter.val
54 | log_dict[f"avg_{name}"] = meter.avg
55 | return log_dict
56 |
--------------------------------------------------------------------------------
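
A short usage sketch of the two meters, roughly as a training loop would drive them (assumes the repo root is on the Python path so misc.meter imports):

import torch

from misc.meter import AverageMeter, MetricMeter

loss_meter = AverageMeter()
for loss in [0.9, 0.7, 0.5]:
    loss_meter.update(loss)
print(f"{loss_meter.val:.2f} {loss_meter.avg:.2f}")  # 0.50 0.70

metrics = MetricMeter(delimiter="  ")
metrics.update({"loss": torch.tensor(0.5), "psnr": 28.0})  # tensors are .item()-ed automatically
metrics.update({"loss": torch.tensor(0.3), "psnr": 30.0})
print(metrics)                 # loss 0.3000 (0.4000)  psnr 30.0000 (29.0000)
print(metrics.get_log_dict())  # current values plus avg_* running averages
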
/misc/morphing.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://github.com/andrewdcampbell/face-movie/blob/master/face-movie/face_morph.py
3 | """
4 |
5 | import copy
6 | from typing import Dict
7 |
8 | import cv2
9 | import numpy as np
10 | from scipy.spatial import Delaunay
11 |
12 |
13 | def warp_im(im, src_landmarks, dst_landmarks, triangulation, flags=cv2.INTER_LINEAR):
14 | im_out = im.copy()
15 |
16 | for i in range(len(triangulation)):
17 | src_tri = src_landmarks[triangulation[i]]
18 | dst_tri = dst_landmarks[triangulation[i]]
19 | morph_triangle(im, im_out, src_tri, dst_tri, flags)
20 |
21 | return im_out
22 |
23 |
24 | def morph_triangle(im, im_out, src_tri, dst_tri, flags):
25 | sr = cv2.boundingRect(np.float32([src_tri]))
26 | dr = cv2.boundingRect(np.float32([dst_tri]))
27 | cropped_src_tri = [(src_tri[i][0] - sr[0], src_tri[i][1] - sr[1]) for i in range(3)]
28 | cropped_dst_tri = [(dst_tri[i][0] - dr[0], dst_tri[i][1] - dr[1]) for i in range(3)]
29 |
30 | mask = np.zeros((dr[3], dr[2], 3), dtype=np.float32)
31 | cv2.fillConvexPoly(mask, np.int32(cropped_dst_tri), (1.0, 1.0, 1.0), 16, 0)
32 |
33 | cropped_im = im[sr[1] : sr[1] + sr[3], sr[0] : sr[0] + sr[2]]
34 |
35 | size = (dr[2], dr[3])
36 | warpImage1 = affine_transform(cropped_im, cropped_src_tri, cropped_dst_tri, size, flags)
37 |
38 | im_out[dr[1] : dr[1] + dr[3], dr[0] : dr[0] + dr[2]] = (
39 | im_out[dr[1] : dr[1] + dr[3], dr[0] : dr[0] + dr[2]] * (1 - mask) + warpImage1 * mask
40 | )
41 |
42 |
43 | def affine_transform(src, src_tri, dst_tri, size, flags=cv2.INTER_LINEAR):
44 | M = cv2.getAffineTransform(np.float32(src_tri), np.float32(dst_tri))
45 | dst = cv2.warpAffine(src, M, size, flags=flags, borderMode=cv2.BORDER_REPLICATE)
46 | return dst
47 |
48 |
49 | def morph_seq(
50 | source_img: np.ndarray,
51 | target_img: np.ndarray,
52 | source_landmarks: np.ndarray,
53 | target_landmarks: np.ndarray,
54 | source_mask: np.ndarray,
55 | target_mask: np.ndarray,
56 | comp_list: Dict[int, float],
57 | ):
58 | source_img = np.float32(source_img)
59 | target_img = np.float32(target_img)
60 | source_mask = np.repeat(np.float32(source_mask[..., None]), 3, axis=-1)
61 |
62 | triangulation = Delaunay(source_landmarks).simplices
63 | warped_source_mask = warp_im(
64 | source_mask, source_landmarks, target_landmarks, triangulation, flags=cv2.INTER_NEAREST
65 | )[..., 0]
66 | warped_source = warp_im(source_img, source_landmarks, target_landmarks, triangulation)
67 | # warped_source[warped_source == 0] = target_img[warped_source == 0]
68 |
69 | un_covered_mask = np.zeros_like(target_mask, dtype="bool")
70 | blended = copy.deepcopy(target_img)
71 | for comp, alpha in comp_list.items():
72 | if comp == 9 or comp == 13:
73 | target_comp_mask = np.logical_or(target_mask == 9, target_mask == 13)
74 | source_comp_mask = np.logical_or(warped_source_mask == 9, warped_source_mask == 13)
75 | else:
76 | target_comp_mask = target_mask == comp
77 | source_comp_mask = warped_source_mask == comp
78 |
79 | union_comp_mask = target_comp_mask & source_comp_mask
80 |
81 | blended[union_comp_mask] = (
82 | alpha * warped_source[union_comp_mask] + (1 - alpha) * target_img[union_comp_mask]
83 | )
84 | target_comp_mask[union_comp_mask] = 0
85 | un_covered_mask[target_comp_mask] = 1
86 |
87 | return blended.astype("uint8"), un_covered_mask
88 |
--------------------------------------------------------------------------------
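
A minimal call to morph_seq on synthetic data. The landmarks, masks, and single-entry comp_list below are invented for illustration only; they exercise the per-component blend blended = alpha * warped_source + (1 - alpha) * target_img:

import numpy as np

from misc.morphing import morph_seq  # assumes the repo root is on the Python path

H = W = 64
rng = np.random.default_rng(0)
source_img = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8)
target_img = rng.integers(0, 256, size=(H, W, 3), dtype=np.uint8)

# Identical landmark sets on both images, so the warp is essentially an identity.
landmarks = np.array([[0, 0], [W - 1, 0], [0, H - 1], [W - 1, H - 1], [W // 2, H // 2]])

# Label every pixel as component 4 ("face" in the MT parsing convention used above).
source_mask = np.full((H, W), 4, dtype=np.uint8)
target_mask = np.full((H, W), 4, dtype=np.uint8)

blended, uncovered = morph_seq(
    source_img, target_img, landmarks, landmarks,
    source_mask, target_mask, comp_list={4: 0.8},
)
print(blended.shape, blended.dtype, uncovered.sum())  # (64, 64, 3) uint8 0
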
/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import build_diffuser_model as build_model
2 |
3 | __all__ = ["build_model"]
4 |
--------------------------------------------------------------------------------
/modeling/generate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import tqdm
4 | from PIL import Image
5 |
6 |
7 | @torch.no_grad()
8 | def generate_image_grid(
9 | net: "EDMPrecond",  # forward reference; EDMPrecond is not imported in this module
10 | dest_path,
11 | seed=0,
12 | gridw=4,
13 | gridh=4,
14 | device=torch.device("cuda"),
15 | num_steps=256,
16 | sigma_min=0.002,
17 | sigma_max=80,
18 | rho=7,
19 | S_churn=40,
20 | S_min=0.05,
21 | S_max=50,
22 | S_noise=1,
23 | ):
24 | net.eval()
25 | batch_size = gridw * gridh
26 | torch.manual_seed(seed)
27 |
28 | # Pick latents and labels.
29 | print(f"Generating {batch_size} images...")
30 | latents = torch.randn(
31 | [batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device
32 | )
33 | class_labels = None
34 | if net.label_dim:
35 | class_labels = torch.eye(net.label_dim, device=device)[
36 | torch.randint(net.label_dim, size=[batch_size], device=device)
37 | ]
38 |
39 | # Adjust noise levels based on what's supported by the network.
40 | sigma_min = max(sigma_min, net.sigma_min)
41 | sigma_max = min(sigma_max, net.sigma_max)
42 |
43 | # Time step discretization.
44 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device)
45 | t_steps = (
46 | sigma_max ** (1 / rho)
47 | + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
48 | ) ** rho
49 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
50 |
51 | # Main sampling loop.
52 | x_next = latents.to(torch.float64) * t_steps[0]
53 | for i, (t_cur, t_next) in tqdm.tqdm(
54 | list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit="step"
55 | ): # 0, ..., N-1
56 | x_cur = x_next
57 |
58 | # Increase noise temporarily.
59 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
60 | t_hat = net.round_sigma(t_cur + gamma * t_cur)
61 | x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * torch.randn_like(x_cur)
62 |
63 | # Euler step.
64 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
65 | d_cur = (x_hat - denoised) / t_hat
66 | x_next = x_hat + (t_next - t_hat) * d_cur
67 |
68 | # Apply 2nd order correction.
69 | if i < num_steps - 1:
70 | denoised = net(x_next, t_next, class_labels).to(torch.float64)
71 | d_prime = (x_next - denoised) / t_next
72 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
73 |
74 | # Save image grid.
75 | print(f'Saving image grid to "{dest_path}"...')
76 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8)
77 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2)
78 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels)
79 | image = image.cpu().numpy()
80 | Image.fromarray(image, "RGB").save(dest_path)
81 | print("Done.")
82 |
--------------------------------------------------------------------------------
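
The time-step discretization above is the standard EDM rho-spaced sigma schedule; here it is in isolation, with the same defaults as generate_image_grid, so the end points can be checked directly:

import torch

num_steps, sigma_min, sigma_max, rho = 256, 0.002, 80.0, 7

step_indices = torch.arange(num_steps, dtype=torch.float64)
t_steps = (
    sigma_max ** (1 / rho)
    + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
) ** rho
t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])  # append t_N = 0

# The schedule starts at sigma_max, decays to sigma_min, and ends exactly at 0.
print(t_steps[0].item(), t_steps[-2].item(), t_steps[-1].item())  # ~80.0 ~0.002 0.0
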
/modeling/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Loss functions used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import torch
12 |
13 | from torch_utils import persistence
14 |
15 | # ----------------------------------------------------------------------------
16 | # Loss function corresponding to the variance preserving (VP) formulation
17 | # from the paper "Score-Based Generative Modeling through Stochastic
18 | # Differential Equations".
19 |
20 |
21 | @persistence.persistent_class
22 | class VPLoss:
23 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
24 | self.beta_d = beta_d
25 | self.beta_min = beta_min
26 | self.epsilon_t = epsilon_t
27 |
28 | def __call__(self, net, images, labels, augment_pipe=None):
29 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
30 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
31 | weight = 1 / sigma**2
32 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
33 | n = torch.randn_like(y) * sigma
34 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
35 | loss = weight * ((D_yn - y) ** 2)
36 | return loss
37 |
38 | def sigma(self, t):
39 | t = torch.as_tensor(t)
40 | return ((0.5 * self.beta_d * (t**2) + self.beta_min * t).exp() - 1).sqrt()
41 |
42 |
43 | # ----------------------------------------------------------------------------
44 | # Loss function corresponding to the variance exploding (VE) formulation
45 | # from the paper "Score-Based Generative Modeling through Stochastic
46 | # Differential Equations".
47 |
48 |
49 | @persistence.persistent_class
50 | class VELoss:
51 | def __init__(self, sigma_min=0.02, sigma_max=100):
52 | self.sigma_min = sigma_min
53 | self.sigma_max = sigma_max
54 |
55 | def __call__(self, net, images, labels, augment_pipe=None):
56 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device)
57 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
58 | weight = 1 / sigma**2
59 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
60 | n = torch.randn_like(y) * sigma
61 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
62 | loss = weight * ((D_yn - y) ** 2)
63 | return loss
64 |
65 |
66 | # ----------------------------------------------------------------------------
67 | # Improved loss function proposed in the paper "Elucidating the Design Space
68 | # of Diffusion-Based Generative Models" (EDM).
69 |
70 |
71 | @persistence.persistent_class
72 | class EDMLoss:
73 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
74 | self.P_mean = P_mean
75 | self.P_std = P_std
76 | self.sigma_data = sigma_data
77 |
78 | def __call__(self, net, images, labels=None, augment_pipe=None):
79 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
80 | sigma = (rnd_normal * self.P_std + self.P_mean).exp()
81 | weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
82 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
83 | n = torch.randn_like(y) * sigma
84 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
85 | loss = weight * ((D_yn - y) ** 2)
86 | return loss
87 |
88 |
89 | # ----------------------------------------------------------------------------
90 |
--------------------------------------------------------------------------------
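
A small sketch of how EDMLoss is driven: images are noised at a log-normally sampled sigma, the network is asked to denoise them, and the squared error is re-weighted. IdentityDenoiser is a made-up stand-in for the real EDM-preconditioned network (assumes the repo root is on the Python path so torch_utils resolves):

import torch

from modeling.loss import EDMLoss


class IdentityDenoiser(torch.nn.Module):
    # Toy stand-in for the EDM-preconditioned network: it just echoes its input,
    # whereas a real model would predict the clean image y.
    def forward(self, x, sigma, labels=None, augment_labels=None):
        return x


loss_fn = EDMLoss()
images = torch.randn(2, 3, 16, 16)
per_pixel = loss_fn(IdentityDenoiser(), images)  # weighted (D_yn - y) ** 2, same shape as images
print(per_pixel.shape, per_pixel.mean().item())
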
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.7.0
2 | diffusers==0.33.1
3 | transformers==4.52.3
4 | tabulate==0.9.0
5 | yacs==0.1.8
6 | scipy==1.11.1
7 | pandas==2.2.0
8 | loguru==0.7.2
9 | colorama==0.4.6
10 | scikit-learn==1.3.0
11 | einops==0.6.0
12 | dlib==19.24.2
13 | mediapipe==0.10.9
14 | scikit-image==0.21.0
15 | aim==3.24.0
16 | gradio==5.31.0
17 | gdown
--------------------------------------------------------------------------------
/script/train_text_to_image.sh:
--------------------------------------------------------------------------------
1 | MODEL_NAME="runwayml/stable-diffusion-v1-5"
2 |
3 | accelerate launch --mixed_precision="fp16" sd_training/train_text_to_image.py \
4 | --pretrained_model_name_or_path=$MODEL_NAME \
5 | --train_data_dir data/mtdataset/images \
6 | --train_json_file data/mt_text_anno.json \
7 | --resolution=512 --center_crop --random_flip \
8 | --train_batch_size=1 \
9 | --gradient_accumulation_steps=4 \
10 | --gradient_checkpointing \
11 | --max_train_steps=24000 \
12 | --checkpointing_steps=3000 \
13 | --learning_rate=1e-05 \
14 | --max_grad_norm=1 \
15 | --mixed_precision="fp16" \
16 | --enable_xformers_memory_efficient_attention \
17 | --snr_gamma=5.0 \
18 | --lr_scheduler="constant" --lr_warmup_steps=0 \
19 | --output_dir="runs/sd_512_512" \
20 |
--------------------------------------------------------------------------------
/torch_fidelity/__init__.py:
--------------------------------------------------------------------------------
1 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase
2 | from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
3 | from torch_fidelity.feature_extractor_clip import FeatureExtractorCLIP
4 | from torch_fidelity.generative_model_base import GenerativeModelBase
5 | from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper
6 | from torch_fidelity.generative_model_onnx import GenerativeModelONNX
7 | from torch_fidelity.metric_fid import KEY_METRIC_FID
8 | from torch_fidelity.metric_isc import KEY_METRIC_ISC_MEAN, KEY_METRIC_ISC_STD
9 | from torch_fidelity.metric_kid import KEY_METRIC_KID_MEAN, KEY_METRIC_KID_STD
10 | from torch_fidelity.metric_ppl import KEY_METRIC_PPL_MEAN, KEY_METRIC_PPL_STD, KEY_METRIC_PPL_RAW
11 | from torch_fidelity.metric_prc import KEY_METRIC_PRECISION, KEY_METRIC_RECALL, KEY_METRIC_F_SCORE
12 | from torch_fidelity.metrics import calculate_metrics
13 | from torch_fidelity.registry import (
14 | register_dataset,
15 | register_feature_extractor,
16 | register_sample_similarity,
17 | register_noise_source,
18 | register_interpolation,
19 | )
20 | from torch_fidelity.sample_similarity_base import SampleSimilarityBase
21 | from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS
22 | from torch_fidelity.version import __version__
23 |
--------------------------------------------------------------------------------
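
For reference, a minimal calculate_metrics call against this vendored copy, mirroring how the evaluation scripts in misc/ invoke it. The synthetic inputs avoid touching any dataset folders, and the Inception weights are fetched automatically when no local path is given; keyword support may differ slightly in this modified copy, so treat the sketch as illustrative:

import torch_fidelity
from torch_fidelity.datasets import RandomlyGeneratedDataset

# Two independent sets of random uint8 images standing in for generated vs. reference samples.
generated = RandomlyGeneratedDataset(128, 3, 64, 64, seed=2021)
reference = RandomlyGeneratedDataset(128, 3, 64, 64, seed=2022)

metrics = torch_fidelity.calculate_metrics(
    input1=generated,
    input2=reference,
    kid=True,
    kid_subset_size=64,   # must not exceed the number of samples in either input
    device="cuda",        # the evaluation scripts in misc/ pass device="cuda" as well
    cache=False,
    verbose=False,
    feature_extractor="inception-v3-compat",
)
print(metrics["kernel_inception_distance_mean"])
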
/torch_fidelity/datasets.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from contextlib import redirect_stdout
3 |
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 | from torchvision.datasets import CIFAR10, STL10, CIFAR100
8 | import torchvision.transforms.functional as F
9 |
10 | from torch_fidelity.helpers import vassert
11 |
12 |
13 | class TransformPILtoRGBTensor:
14 | def __call__(self, img):
15 | vassert(type(img) is Image.Image, "Input is not a PIL.Image")
16 | return F.pil_to_tensor(img)
17 |
18 |
19 | class ImagesPathDataset(Dataset):
20 | def __init__(self, files, transforms=None):
21 | self.files = files
22 | self.transforms = TransformPILtoRGBTensor() if transforms is None else transforms
23 |
24 | def __len__(self):
25 | return len(self.files)
26 |
27 | def __getitem__(self, i):
28 | path = self.files[i]
29 | img = Image.open(path).convert("RGB")
30 | img = self.transforms(img)
31 | return img
32 |
33 |
34 | class Cifar10_RGB(CIFAR10):
35 | def __init__(self, *args, **kwargs):
36 | with redirect_stdout(sys.stderr):
37 | super().__init__(*args, **kwargs)
38 |
39 | def __getitem__(self, index):
40 | img, target = super().__getitem__(index)
41 | return img
42 |
43 |
44 | class Cifar100_RGB(CIFAR100):
45 | def __init__(self, *args, **kwargs):
46 | with redirect_stdout(sys.stderr):
47 | super().__init__(*args, **kwargs)
48 |
49 | def __getitem__(self, index):
50 | img, target = super().__getitem__(index)
51 | return img
52 |
53 |
54 | class STL10_RGB(STL10):
55 | def __init__(self, *args, **kwargs):
56 | with redirect_stdout(sys.stderr):
57 | super().__init__(*args, **kwargs)
58 |
59 | def __getitem__(self, index):
60 | img, target = super().__getitem__(index)
61 | return img
62 |
63 |
64 | class RandomlyGeneratedDataset(Dataset):
65 | def __init__(self, num_samples, *dimensions, dtype=torch.uint8, seed=2021):
66 | vassert(dtype == torch.uint8, "Unsupported dtype")
67 | rng_stash = torch.get_rng_state()
68 | try:
69 | torch.manual_seed(seed)
70 | self.imgs = torch.randint(0, 255, (num_samples, *dimensions), dtype=dtype)
71 | finally:
72 | torch.set_rng_state(rng_stash)
73 |
74 | def __len__(self):
75 | return self.imgs.shape[0]
76 |
77 | def __getitem__(self, i):
78 | return self.imgs[i]
79 |
--------------------------------------------------------------------------------
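
ImagesPathDataset simply opens each path as RGB and converts it to a uint8 CHW tensor; a quick self-contained check using a single temporary image (the file name and color are arbitrary):

import os
import tempfile

from PIL import Image

from torch_fidelity.datasets import ImagesPathDataset

# Write a single dummy RGB image to a temporary directory.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "sample.png")
Image.new("RGB", (32, 32), color=(200, 30, 30)).save(path)

ds = ImagesPathDataset([path])
img = ds[0]
print(len(ds), img.shape, img.dtype)  # 1 torch.Size([3, 32, 32]) torch.uint8
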
/torch_fidelity/defaults.py:
--------------------------------------------------------------------------------
1 | DEFAULTS = {
2 | "input1": None,
3 | "input2": None,
4 | "input3": None,
5 | "input4": None,
6 | "cuda": True,
7 | "batch_size": 64,
8 | "isc": False,
9 | "fid": False,
10 | "kid": False,
11 | "prc": False,
12 | "ppl": False,
13 | "feature_extractor": None,
14 | "feature_layer_isc": None,
15 | "feature_layer_fid": None,
16 | "feature_layer_kid": None,
17 | "feature_layer_prc": None,
18 | "feature_extractor_weights_path": None,
19 | "feature_extractor_internal_dtype": None,
20 | "feature_extractor_compile": False,
21 | "isc_splits": 10,
22 | "kid_subsets": 100,
23 | "kid_subset_size": 1000,
24 | "kid_kernel": "poly",
25 | "kid_kernel_poly_degree": 3,
26 | "kid_kernel_poly_gamma": None,
27 | "kid_kernel_poly_coef0": 1,
28 | "kid_kernel_rbf_sigma": 10,
29 | "ppl_epsilon": 1e-4,
30 | "ppl_reduction": "mean",
31 | "ppl_sample_similarity": "lpips-vgg16",
32 | "ppl_sample_similarity_resize": 64,
33 | "ppl_sample_similarity_dtype": "uint8",
34 | "ppl_discard_percentile_lower": 1,
35 | "ppl_discard_percentile_higher": 99,
36 | "ppl_z_interp_mode": "lerp",
37 | "prc_neighborhood": 3,
38 | "prc_batch_size": 10000,
39 | "samples_shuffle": True,
40 | "samples_find_deep": False,
41 | "samples_find_ext": "png,jpg,jpeg",
42 | "samples_ext_lossy": "jpg,jpeg",
43 | "samples_resize_and_crop": 0,
44 | "datasets_root": None,
45 | "datasets_download": True,
46 | "cache_root": None,
47 | "cache": True,
48 | "input1_cache_name": None,
49 | "input1_model_z_type": "normal",
50 | "input1_model_z_size": None,
51 | "input1_model_num_classes": 0,
52 | "input1_model_num_samples": None,
53 | "input2_cache_name": None,
54 | "input2_model_z_type": "normal",
55 | "input2_model_z_size": None,
56 | "input2_model_num_classes": 0,
57 | "input2_model_num_samples": None,
58 | "input3_cache_name": None,
59 | "input3_model_z_type": "normal",
60 | "input3_model_z_size": None,
61 | "input3_model_num_classes": 0,
62 | "input3_model_num_samples": None,
63 | "input4_cache_name": None,
64 | "input4_model_z_type": "normal",
65 | "input4_model_z_size": None,
66 | "input4_model_num_classes": 0,
67 | "input4_model_num_samples": None,
68 | "rng_seed": 2020,
69 | "save_cpu_ram": False,
70 | "verbose": True,
71 | }
72 |
--------------------------------------------------------------------------------
/torch_fidelity/deprecations.py:
--------------------------------------------------------------------------------
1 | DEPRECATIONS = {
2 | "kid_degree": {
3 | "new_name": "kid_kernel_poly_degree",
4 | "since": "0.4.0",
5 | "reason": "Supporting various kernel functions, including RBF, to enable computing the metric proposed in https://arxiv.org/pdf/2401.09603.pdf",
6 | },
7 | "kid_gamma": {
8 | "new_name": "kid_kernel_poly_gamma",
9 | "since": "0.4.0",
10 | "reason": "Supporting various kernel functions, including RBF, to enable computing the metric proposed in https://arxiv.org/pdf/2401.09603.pdf",
11 | },
12 | "kid_coef0": {
13 | "new_name": "kid_kernel_poly_coef0",
14 | "since": "0.4.0",
15 | "reason": "Supporting various kernel functions, including RBF, to enable computing the metric proposed in https://arxiv.org/pdf/2401.09603.pdf",
16 | },
17 | }
18 |
--------------------------------------------------------------------------------
/torch_fidelity/feature_extractor_base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from torch_fidelity.helpers import vassert
4 |
5 |
6 | class FeatureExtractorBase(nn.Module):
7 | def __init__(self, name, features_list):
8 | """
9 | Base class for feature extractors that can be used in :func:`calculate_metrics`.
10 |
11 | Args:
12 |
13 | name (str): Unique name of the subclassed feature extractor, must be the same as used in
14 | :func:`register_feature_extractor`.
15 |
16 | features_list (list): List of feature names, provided by the subclassed feature extractor.
17 | """
18 | super(FeatureExtractorBase, self).__init__()
19 | vassert(type(name) is str, "Feature extractor name must be a string")
20 | vassert(type(features_list) in (list, tuple), "Wrong features list type")
21 | vassert(
22 | all((a in self.get_provided_features_list() for a in features_list)),
23 | f"Requested features {tuple(features_list)} are not on the list provided by the selected feature extractor "
24 | f"{self.get_provided_features_list()}",
25 | )
26 | vassert(len(features_list) == len(set(features_list)), "Duplicate features requested")
27 | vassert(len(features_list) > 0, "No features requested")
28 | self.name = name
29 | self.features_list = features_list
30 |
31 | def get_name(self):
32 | return self.name
33 |
34 | @staticmethod
35 | def get_provided_features_list():
36 | """
37 | Returns a tuple of feature names, extracted by the subclassed feature extractor.
38 | """
39 | raise NotImplementedError
40 |
41 | @staticmethod
42 | def get_default_feature_layer_for_metric(metric):
43 | """
44 | Returns a default feature name to be used for the metric computation.
45 | """
46 | raise NotImplementedError
47 |
48 | @staticmethod
49 | def can_be_compiled():
50 | """
51 | Indicates whether the subclass can be safely wrapped with torch.compile.
52 | """
53 | raise NotImplementedError
54 |
55 | @staticmethod
56 | def get_dummy_input_for_compile():
57 | """
58 | Returns a dummy input for compilation
59 | """
60 | raise NotImplementedError
61 |
62 | def get_requested_features_list(self):
63 | return self.features_list
64 |
65 | def convert_features_tuple_to_dict(self, features):
66 | # The only compound return type of the forward function amenable to JIT tracing is tuple.
67 | # This function simply helps to recover the mapping.
68 | vassert(
69 | type(features) is tuple and len(features) == len(self.features_list),
70 | "Features must be the output of forward function",
71 | )
72 | return dict(((name, feature) for name, feature in zip(self.features_list, features)))
73 |
74 | def forward(self, input):
75 | """
76 | Returns a tuple of tensors extracted from the `input`, in the same order as they are provided by
77 | `get_provided_features_list()`.
78 | """
79 | raise NotImplementedError
80 |
--------------------------------------------------------------------------------
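
A minimal, hypothetical subclass showing the contract FeatureExtractorBase expects: declare the provided feature names, return them from forward as a tuple, and answer the compile-related queries. MeanColorFeatureExtractor is invented purely for illustration:

import torch

from torch_fidelity.feature_extractor_base import FeatureExtractorBase


class MeanColorFeatureExtractor(FeatureExtractorBase):
    # Toy extractor: its single "feature" is the per-channel mean color of each image.

    def forward(self, x):
        feat = x.to(torch.float32).mean(dim=(2, 3))  # N x 3
        return (feat,)

    @staticmethod
    def get_provided_features_list():
        return ("mean_color",)

    @staticmethod
    def get_default_feature_layer_for_metric(metric):
        return "mean_color"

    @staticmethod
    def can_be_compiled():
        return False

    @staticmethod
    def get_dummy_input_for_compile():
        return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)


fx = MeanColorFeatureExtractor("mean-color", ["mean_color"])
out = fx(torch.randint(0, 255, (2, 3, 8, 8), dtype=torch.uint8))
print(fx.convert_features_tuple_to_dict(out)["mean_color"].shape)  # torch.Size([2, 3])
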
/torch_fidelity/feature_extractor_dinov2.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 |
4 | import torch
5 | import torchvision
6 |
7 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase
8 | from torch_fidelity.helpers import vassert, text_to_dtype, CleanStderr
9 |
10 | from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
11 |
12 |
13 | MODEL_METADATA = {
14 | "dinov2-vit-s-14": "dinov2_vits14", # dim=384
15 | "dinov2-vit-b-14": "dinov2_vitb14", # dim=768
16 | "dinov2-vit-l-14": "dinov2_vitl14", # dim=1024
17 | "dinov2-vit-g-14": "dinov2_vitg14", # dim=1536
18 | }
19 |
20 |
21 | class FeatureExtractorDinoV2(FeatureExtractorBase):
22 | INPUT_IMAGE_SIZE = 224
23 |
24 | def __init__(
25 | self,
26 | name,
27 | features_list,
28 | feature_extractor_weights_path=None,
29 | feature_extractor_internal_dtype=None,
30 | **kwargs,
31 | ):
32 | """
33 | DinoV2 feature extractor for 2D RGB 24bit images.
34 |
35 | Args:
36 |
37 | name (str): Unique name of the feature extractor, must be the same as used in
38 | :func:`register_feature_extractor`.
39 |
40 | features_list (list): A list of the requested feature names, which will be produced for each input. This
41 | feature extractor provides the following features:
42 |
43 | - 'dinov2'
44 |
45 | feature_extractor_weights_path (str): Path to pretrained DinoV2 model weights in PyTorch format.
46 | Loading local weights is not implemented for this extractor; weights are fetched via torch.hub if `None`.
47 |
48 | feature_extractor_internal_dtype (str): dtype to use inside the feature extractor. Specifying it may improve
49 | numerical precision in some cases. Supported values are 'float32' (default), and 'float64'.
50 | """
51 | super(FeatureExtractorDinoV2, self).__init__(name, features_list)
52 | vassert(
53 | feature_extractor_internal_dtype in ("float32", "float64", None),
54 | "Only 32-bit floats are supported for internal dtype of this feature extractor",
55 | )
56 |
57 | vassert(name in MODEL_METADATA, f"Model {name} not found; available models = {list(MODEL_METADATA.keys())}")
58 | self.feature_extractor_internal_dtype = text_to_dtype(feature_extractor_internal_dtype, "float32")
59 |
60 | with CleanStderr(["xFormers not available", "Using cache found in"], sys.stderr), warnings.catch_warnings():
61 | warnings.filterwarnings("ignore", message="xFormers is not available")
62 | if feature_extractor_weights_path is None:
63 | self.model = torch.hub.load("facebookresearch/dinov2", MODEL_METADATA[name])
64 | else:
65 | raise NotImplementedError
66 |
67 | self.to(self.feature_extractor_internal_dtype)
68 | self.requires_grad_(False)
69 | self.eval()
70 |
71 | def forward(self, x):
72 | vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8")
73 | vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}")
74 |
75 | x = x.to(self.feature_extractor_internal_dtype)
76 | # N x 3 x ? x ?
77 |
78 | x = interpolate_bilinear_2d_like_tensorflow1x(
79 | x,
80 | size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
81 | align_corners=False,
82 | )
83 | # N x 3 x 224 x 224
84 |
85 | x = torchvision.transforms.functional.normalize(
86 | x,
87 | (255 * 0.485, 255 * 0.456, 255 * 0.406),
88 | (255 * 0.229, 255 * 0.224, 255 * 0.225),
89 | inplace=False,
90 | )
91 | # N x 3 x 224 x 224
92 |
93 | x = self.model(x)
94 |
95 | out = {
96 | "dinov2": x.to(torch.float32),
97 | }
98 |
99 | return tuple(out[a] for a in self.features_list)
100 |
101 | @staticmethod
102 | def get_provided_features_list():
103 | return ("dinov2",)
104 |
105 | @staticmethod
106 | def get_default_feature_layer_for_metric(metric):
107 | return {
108 | "isc": "dinov2",
109 | "fid": "dinov2",
110 | "kid": "dinov2",
111 | "prc": "dinov2",
112 | }[metric]
113 |
114 | @staticmethod
115 | def can_be_compiled():
116 | return True
117 |
118 | @staticmethod
119 | def get_dummy_input_for_compile():
120 | return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
121 |
--------------------------------------------------------------------------------
/torch_fidelity/feature_extractor_vgg16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torchvision
4 |
5 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase
6 | from torch_fidelity.helpers import text_to_dtype, vassert
7 | from torch_fidelity.interpolate_compat_tensorflow import (
8 | interpolate_bilinear_2d_like_tensorflow1x,
9 | )
10 | from torch_fidelity.utils_torchvision import torchvision_load_pretrained_vgg16
11 |
12 |
13 | class FeatureExtractorVGG16(FeatureExtractorBase):
14 | INPUT_IMAGE_SIZE = 224
15 |
16 | def __init__(
17 | self,
18 | name,
19 | features_list,
20 | feature_extractor_weights_path=None,
21 | feature_extractor_internal_dtype=None,
22 | **kwargs,
23 | ):
24 | """
25 | VGG16 feature extractor for 2D RGB 24bit images.
26 |
27 | Args:
28 |
29 | name (str): Unique name of the feature extractor, must be the same as used in
30 | :func:`register_feature_extractor`.
31 |
32 | features_list (list): A list of the requested feature names, which will be produced for each input. This
33 | feature extractor provides the following features:
34 |
35 | - 'fc2'
36 | - 'fc2_relu'
37 |
38 | feature_extractor_weights_path (str): Path to the pretrained VGG16 model weights in PyTorch format.
39 | Loads and adapts a local state dict if given; downloads the torchvision weights if `None`.
40 |
41 | feature_extractor_internal_dtype (str): dtype to use inside the feature extractor. Specifying it may improve
42 | numerical precision in some cases. Supported values are 'float32' (default), and 'float64'.
43 | """
44 | super(FeatureExtractorVGG16, self).__init__(name, features_list)
45 | vassert(
46 | feature_extractor_internal_dtype in ("float32", "float64", None),
47 | "Only 32-bit floats are supported for internal dtype of this feature extractor",
48 | )
49 | self.feature_extractor_internal_dtype = text_to_dtype(
50 | feature_extractor_internal_dtype, "float32"
51 | )
52 |
53 | if feature_extractor_weights_path is None:
54 | self.model = torchvision_load_pretrained_vgg16(**kwargs)
55 | else:
56 | state_dict = torch.load(feature_extractor_weights_path)
57 | self.model = torchvision.models.vgg16()
58 | new_state_dict = {}
59 | for k, v in state_dict.items():
60 | new_state_dict[k.replace("model.", "")] = v
61 | print(self.model.load_state_dict(new_state_dict, strict=False))
62 | for cls_tail_id in (6, 5, 4):
63 | del self.model.classifier[cls_tail_id]
64 |
65 | self.to(self.feature_extractor_internal_dtype)
66 | self.requires_grad_(False)
67 | self.eval()
68 |
69 | def forward(self, x):
70 | vassert(
71 | torch.is_tensor(x) and x.dtype == torch.uint8,
72 | "Expecting image as torch.Tensor with dtype=torch.uint8",
73 | )
74 | vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}")
75 | features = {}
76 | remaining_features = self.features_list.copy()
77 |
78 | x = x.to(self.feature_extractor_internal_dtype)
79 | # N x 3 x ? x ?
80 |
81 | x = interpolate_bilinear_2d_like_tensorflow1x(
82 | x,
83 | size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
84 | align_corners=False,
85 | )
86 | # N x 3 x 224 x 224
87 |
88 | x = torchvision.transforms.functional.normalize(
89 | x,
90 | (255 * 0.485, 255 * 0.456, 255 * 0.406),
91 | (255 * 0.229, 255 * 0.224, 255 * 0.225),
92 | inplace=False,
93 | )
94 | # N x 3 x 224 x 224
95 |
96 | x = self.model(x)
97 |
98 | if "fc2" in remaining_features:
99 | features["fc2"] = x.to(torch.float32)
100 | remaining_features.remove("fc2")
101 | if len(remaining_features) == 0:
102 | return tuple(features[a] for a in self.features_list)
103 |
104 | features["fc2_relu"] = F.relu(x).to(torch.float32)
105 |
106 | return tuple(features[a] for a in self.features_list)
107 |
108 | @staticmethod
109 | def get_provided_features_list():
110 | return "fc2", "fc2_relu"
111 |
112 | @staticmethod
113 | def get_default_feature_layer_for_metric(metric):
114 | return {
115 | "isc": "fc2_relu",
116 | "fid": "fc2_relu",
117 | "kid": "fc2_relu",
118 | "prc": "fc2_relu",
119 | }[metric]
120 |
121 | @staticmethod
122 | def can_be_compiled():
123 | return True
124 |
125 | @staticmethod
126 | def get_dummy_input_for_compile():
127 | return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)
128 |
--------------------------------------------------------------------------------
/torch_fidelity/generative_model_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 |
5 |
6 | class GenerativeModelBase(ABC, torch.nn.Module):
7 | """
8 | Base class for generative models that can be used as inputs in :func:`calculate_metrics`.
9 | """
10 |
11 | @property
12 | @abstractmethod
13 | def z_size(self):
14 | """
15 | Size of the noise dimension of the generative model (positive integer).
16 | """
17 | pass
18 |
19 | @property
20 | @abstractmethod
21 | def z_type(self):
22 | """
23 | Type of the noise used by the generative model (see :ref:`registry` for a list of preregistered noise
24 | types, see :func:`register_noise_source` for registering a new noise type).
25 | """
26 | pass
27 |
28 | @property
29 | @abstractmethod
30 | def num_classes(self):
31 | """
32 | Number of classes used by a conditional generative model. Must return zero for unconditional models.
33 | """
34 | pass
35 |
--------------------------------------------------------------------------------
/torch_fidelity/generative_model_modulewrapper.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 |
5 | from torch_fidelity.generative_model_base import GenerativeModelBase
6 | from torch_fidelity.helpers import vassert
7 |
8 |
9 | class GenerativeModelModuleWrapper(GenerativeModelBase):
10 | def __init__(self, module, z_size, z_type, num_classes, make_copy=False, make_eval=True, cuda=None):
11 | """
12 | Wraps any generative model :class:`torch.nn.Module`, implements the :class:`GenerativeModelBase` interface, and
13 | provides a few convenience functions.
14 |
15 | Args:
16 |
17 | module (torch.nn.Module): A generative model module, taking a batch of noise samples, and producing
18 | generative samples.
19 |
20 | z_size (int): Size of the noise dimension of the generative model (positive integer).
21 |
22 | z_type (str): Type of the noise used by the generative model (see :ref:`registry` for a list of
23 | preregistered noise types, see :func:`register_noise_source` for registering a new noise type).
24 |
25 | num_classes (int): Number of classes used by a conditional generative model. Must return zero for
26 | unconditional models.
27 |
28 | make_copy (bool): Makes a copy of the model weights if `True`. Default: `False`.
29 |
30 | make_eval (bool): Switches to :class:`torch.nn.Module` evaluation mode upon construction if `True`. Default:
31 | `True`.
32 |
33 | cuda (bool): Moves the module on a CUDA device if `True`, moves to CPU if `False`, does nothing if `None`.
34 | Default: `None`.
35 | """
36 | super().__init__()
37 | vassert(isinstance(module, torch.nn.Module), "Not an instance of torch.nn.Module")
38 | vassert(type(z_size) is int and z_size > 0, "z_size must be a positive integer")
39 | vassert(z_type in ("normal", "unit", "uniform_0_1"), f"z_type={z_type} not implemented")
40 | vassert(type(num_classes) is int and num_classes >= 0, "num_classes must be a non-negative integer")
41 | self.module = module
42 | if make_copy:
43 | self.module = copy.deepcopy(self.module)
44 | if make_eval:
45 | self.module.eval()
46 | if cuda is not None:
47 | if cuda:
48 | self.module = self.module.cuda()
49 | else:
50 | self.module = self.module.cpu()
51 | self._z_size = z_size
52 | self._z_type = z_type
53 | self._num_classes = num_classes
54 |
55 | @property
56 | def z_size(self):
57 | return self._z_size
58 |
59 | @property
60 | def z_type(self):
61 | return self._z_type
62 |
63 | @property
64 | def num_classes(self):
65 | return self._num_classes
66 |
67 | def forward(self, *args, **kwargs):
68 | return self.module.forward(*args, **kwargs)
69 |
--------------------------------------------------------------------------------
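
A short usage sketch: wrapping a toy unconditional generator so it satisfies the GenerativeModelBase interface that calculate_metrics expects. ToyGenerator is invented for illustration:

import torch

from torch_fidelity import GenerativeModelModuleWrapper


class ToyGenerator(torch.nn.Module):
    # Maps a batch of 64-dim noise vectors to fake 3x32x32 uint8-range images.
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(64, 3 * 32 * 32)

    def forward(self, z):
        x = torch.sigmoid(self.fc(z)).view(-1, 3, 32, 32)
        return (x * 255).to(torch.uint8)


gen = GenerativeModelModuleWrapper(ToyGenerator(), z_size=64, z_type="normal", num_classes=0)
print(gen.z_size, gen.z_type, gen.num_classes)   # 64 normal 0
print(gen(torch.randn(4, gen.z_size)).shape)     # torch.Size([4, 3, 32, 32])
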
/torch_fidelity/generative_model_onnx.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from torch_fidelity.generative_model_base import GenerativeModelBase
7 | from torch_fidelity.helpers import vassert
8 |
9 |
10 | class GenerativeModelONNX(GenerativeModelBase):
11 | def __init__(self, path_onnx, z_size, z_type, num_classes):
12 | """
13 | Wraps :obj:`ONNX` generative model, implements the :class:`GenerativeModelBase` interface.
14 |
15 | Args:
16 |
17 | path_onnx (str): Path to a generative model in :obj:`ONNX` format.
18 |
19 | z_size (int): Size of the noise dimension of the generative model (positive integer).
20 |
21 | z_type (str): Type of the noise used by the generative model (see :ref:`registry` for a list of
22 | preregistered noise types, see :func:`register_noise_source` for registering a new noise type).
23 |
24 | num_classes (int): Number of classes used by a conditional generative model. Must return zero for
25 | unconditional models.
26 | """
27 | super().__init__()
28 | vassert(os.path.isfile(path_onnx), f'Model file not found at "{path_onnx}"')
29 | vassert(type(z_size) is int and z_size > 0, "z_size must be a positive integer")
30 | vassert(z_type in ("normal", "unit", "uniform_0_1"), f"z_type={z_type} not implemented")
31 | vassert(type(num_classes) is int and num_classes >= 0, "num_classes must be a non-negative integer")
32 | try:
33 | import onnxruntime
34 | except ImportError as e:
35 | # This message may be removed if onnxruntime becomes a unified package with embedded CUDA dependencies,
36 | # like for example pytorch
37 | print(
38 | "====================================================================================================\n"
39 | "Loading ONNX models in PyTorch requires ONNX runtime package, which we did not want to include in\n"
40 | "torch_fidelity package requirements.txt. The two relevant pip packages are:\n"
41 | " - onnxruntime (pip install onnxruntime), or\n"
42 | " - onnxruntime-gpu (pip install onnxruntime-gpu).\n"
43 | 'If you choose to install "onnxruntime", you will be able to run inference on CPU only - this may be\n'
44 | 'slow. With "onnxruntime-gpu" speed is not an issue, but at run time you might face CUDA toolkit\n'
45 | "versions incompatibility, which can only be resolved by recompiling onnxruntime-gpu from source.\n"
46 | "Alternatively, use calculate_metrics API and pass an instance of GenerativeModelBase as an input.\n"
47 | "===================================================================================================="
48 | )
49 | raise e
50 | self.ort_session = onnxruntime.InferenceSession(path_onnx)
51 | self.input_names = [a.name for a in self.ort_session.get_inputs()]
52 | self._z_size = z_size
53 | self._z_type = z_type
54 | self._num_classes = num_classes
55 |
56 | @property
57 | def z_size(self):
58 | return self._z_size
59 |
60 | @property
61 | def z_type(self):
62 | return self._z_type
63 |
64 | @property
65 | def num_classes(self):
66 | return self._num_classes
67 |
68 | @staticmethod
69 | def to_numpy(tensor):
70 | return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
71 |
72 | def forward(self, *args):
73 | vassert(
74 | len(args) == len(self.input_names),
75 | f"Number of input arguments {len(args)} does not match ONNX model: {self.input_names}",
76 | )
77 | vassert(all(torch.is_tensor(a) for a in args), "All model inputs must be tensors")
78 | ort_input = {self.input_names[i]: self.to_numpy(args[i]) for i in range(len(args))}
79 | ort_output = self.ort_session.run(None, ort_input)
80 | ort_output = ort_output[0]
81 | vassert(isinstance(ort_output, np.ndarray), "Invalid output of ONNX model")
82 | out = torch.from_numpy(ort_output).to(device=args[0].device)
83 | return out
84 |
--------------------------------------------------------------------------------
/torch_fidelity/helpers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import warnings
4 |
5 | import torch
6 |
7 | from torch_fidelity.defaults import DEFAULTS
8 | from torch_fidelity.deprecations import DEPRECATIONS
9 |
10 |
11 | def vassert(truecond, message):
12 | if not truecond:
13 | raise ValueError(message)
14 |
15 |
16 | def vprint(verbose, message):
17 | if verbose:
18 | print(message, file=sys.stderr)
19 |
20 |
21 | def get_kwarg(name, kwargs):
22 | return kwargs.get(name, DEFAULTS[name])
23 |
24 |
25 | def json_decode_string(s):
26 | try:
27 | out = json.loads(s)
28 | except json.JSONDecodeError as e:
29 | print(f"Failed to decode JSON string: {s}", file=sys.stderr)
30 | raise
31 | return out
32 |
33 |
34 | def text_to_dtype(name, default=None):
35 | DTYPES = {
36 | "uint8": torch.uint8,
37 | "float32": torch.float32,
38 | "float64": torch.float32,
39 | }
40 | if default in DTYPES:
41 | default = DTYPES[default]
42 | return DTYPES.get(name, default)
43 |
44 |
45 | class CleanStderr:
46 | def __init__(self, filter_phrases, stream=sys.stderr):
47 | self.filter_phrases = filter_phrases
48 | self.stream = stream
49 |
50 | def __enter__(self):
51 | sys.stderr = self
52 |
53 | def __exit__(self, exc_type, exc_value, traceback):
54 | sys.stderr = self.stream
55 |
56 | def write(self, msg):
57 | if not any(phrase in msg for phrase in self.filter_phrases):
58 | self.stream.write(msg)
59 |
60 | def flush(self):
61 | self.stream.flush()
62 |
63 |
64 | def process_deprecations(cfg):
65 |     for k, v in list(cfg.items()):  # iterate over a snapshot; the loop mutates cfg below
66 | if k not in DEPRECATIONS:
67 | continue
68 | depr = DEPRECATIONS[k]
69 | new_k = depr["new_name"]
70 | cfg[new_k] = v
71 | warnings.warn(
72 | f"Argument \"{k}\" is deprecated since {depr['since']}; use \"{new_k}\" instead. Reason: {depr['reason']}",
73 | FutureWarning,
74 | stacklevel=2,
75 | )
76 | del cfg[k]
77 |
--------------------------------------------------------------------------------
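
A minimal sketch of the helper contract in torch_fidelity/helpers.py above, assuming DEFAULTS in torch_fidelity/defaults.py defines the "verbose" key (it is looked up this way throughout the metric modules); the filter phrase passed to CleanStderr is made up:

    from torch_fidelity.helpers import vassert, vprint, get_kwarg, text_to_dtype, CleanStderr

    vassert(2 + 2 == 4, "never raised")         # raises ValueError(message) when the condition is falsy
    vprint(True, "goes to stderr")              # silent when the first argument is False
    print(get_kwarg("verbose", {}))             # no override in kwargs -> falls back to DEFAULTS["verbose"]
    print(text_to_dtype("float64", "float32"))  # torch.float64; unknown names fall back to the default

    # Suppress one known noisy phrase while the block runs; everything else still reaches stderr.
    with CleanStderr(["some known and harmless warning phrase"]):
        vprint(True, "still visible")
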
/torch_fidelity/interpolate_compat_tensorflow.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.nn.modules.utils import _ntuple
6 |
7 |
8 | def interpolate_bilinear_2d_like_tensorflow1x(input, size=None, scale_factor=None, align_corners=None, method="slow"):
9 | r"""Down/up samples the input to either the given :attr:`size` or the given :attr:`scale_factor`
10 |
11 | Epsilon-exact bilinear interpolation as it is implemented in TensorFlow 1.x:
12 | https://github.com/tensorflow/tensorflow/blob/f66daa493e7383052b2b44def2933f61faf196e0/tensorflow/core/kernels/image_resizer_state.h#L41
13 | https://github.com/tensorflow/tensorflow/blob/6795a8c3a3678fb805b6a8ba806af77ddfe61628/tensorflow/core/kernels/resize_bilinear_op.cc#L85
14 | as per proposal:
15 | https://github.com/pytorch/pytorch/issues/10604#issuecomment-465783319
16 |
17 | Related materials:
18 | https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35
19 | https://jricheimer.github.io/tensorflow/2019/02/11/resize-confusion/
20 | https://machinethink.net/blog/coreml-upsampling/
21 |
22 | Currently only 2D spatial sampling is supported, i.e. expected inputs are 4-D in shape.
23 |
24 | The input dimensions are interpreted in the form:
25 | `mini-batch x channels x height x width`.
26 |
27 | Args:
28 | input (Tensor): the input tensor
29 | size (Tuple[int, int]): output spatial size.
30 | scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
31 | align_corners (bool, optional): Same meaning as in TensorFlow 1.x.
32 | method (str, optional):
33 | 'slow' (1e-4 L_inf error on GPU, bit-exact on CPU, with checkerboard 32x32->299x299), or
34 | 'fast' (1e-3 L_inf error on GPU and CPU, with checkerboard 32x32->299x299)
35 | """
36 | if method not in ("slow", "fast"):
37 |         raise ValueError('method can only be one of "slow", "fast"')
38 |
39 | if input.dim() != 4:
40 | raise ValueError("input must be a 4-D tensor")
41 |
42 | if not torch.is_floating_point(input):
43 | raise ValueError("input must be of floating point dtype")
44 |
45 | if size is not None and (type(size) not in (tuple, list) or len(size) != 2):
46 | raise ValueError("size must be a list or a tuple of two elements")
47 |
48 | if align_corners is None:
49 | raise ValueError("align_corners is not specified (use this function for a complete determinism)")
50 |
51 | def _check_size_scale_factor(dim):
52 | if size is None and scale_factor is None:
53 | raise ValueError("either size or scale_factor should be defined")
54 | if size is not None and scale_factor is not None:
55 | raise ValueError("only one of size or scale_factor should be defined")
56 | if scale_factor is not None and isinstance(scale_factor, tuple) and len(scale_factor) != dim:
57 | raise ValueError(
58 | "scale_factor shape must match input shape. "
59 | "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
60 | )
61 |
62 | is_tracing = torch._C._get_tracing_state()
63 |
64 | def _output_size(dim):
65 | _check_size_scale_factor(dim)
66 | if size is not None:
67 | if is_tracing:
68 | return [torch.tensor(i) for i in size]
69 | else:
70 | return size
71 | scale_factors = _ntuple(dim)(scale_factor)
72 | # math.floor might return float in py2.7
73 |
74 | # make scale_factor a tensor in tracing so constant doesn't get baked in
75 | if is_tracing:
76 | return [
77 | (torch.floor((input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32)).float()))
78 | for i in range(dim)
79 | ]
80 | else:
81 | return [int(math.floor(float(input.size(i + 2)) * scale_factors[i])) for i in range(dim)]
82 |
83 | def tf_calculate_resize_scale(in_size, out_size):
84 | if align_corners:
85 | if is_tracing:
86 | return (in_size - 1) / (out_size.float() - 1).clamp(min=1)
87 | else:
88 | return (in_size - 1) / max(1, out_size - 1)
89 | else:
90 | if is_tracing:
91 | return in_size / out_size.float()
92 | else:
93 | return in_size / out_size
94 |
95 | out_size = _output_size(2)
96 | scale_x = tf_calculate_resize_scale(input.shape[3], out_size[1])
97 | scale_y = tf_calculate_resize_scale(input.shape[2], out_size[0])
98 |
99 | def resample_using_grid_sample():
100 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device)
101 | grid_x = grid_x * (2 * scale_x / (input.shape[3] - 1)) - 1
102 |
103 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device)
104 | grid_y = grid_y * (2 * scale_y / (input.shape[2] - 1)) - 1
105 |
106 | grid_x = grid_x.view(1, out_size[1]).repeat(out_size[0], 1)
107 | grid_y = grid_y.view(out_size[0], 1).repeat(1, out_size[1])
108 |
109 | grid_xy = torch.cat((grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)), dim=2).unsqueeze(0)
110 | grid_xy = grid_xy.repeat(input.shape[0], 1, 1, 1)
111 |
112 | out = F.grid_sample(input, grid_xy, mode="bilinear", padding_mode="border", align_corners=True)
113 | return out
114 |
115 | def resample_manually():
116 | grid_x = torch.arange(0, out_size[1], 1, dtype=input.dtype, device=input.device)
117 | grid_x = grid_x * torch.tensor(scale_x, dtype=torch.float32)
118 | grid_x_lo = grid_x.long()
119 | grid_x_hi = (grid_x_lo + 1).clamp_max(input.shape[3] - 1)
120 | grid_dx = grid_x - grid_x_lo.float()
121 |
122 | grid_y = torch.arange(0, out_size[0], 1, dtype=input.dtype, device=input.device)
123 | grid_y = grid_y * torch.tensor(scale_y, dtype=torch.float32)
124 | grid_y_lo = grid_y.long()
125 | grid_y_hi = (grid_y_lo + 1).clamp_max(input.shape[2] - 1)
126 | grid_dy = grid_y - grid_y_lo.float()
127 |
128 | # could be improved with index_select
129 | in_00 = input[:, :, grid_y_lo, :][:, :, :, grid_x_lo]
130 | in_01 = input[:, :, grid_y_lo, :][:, :, :, grid_x_hi]
131 | in_10 = input[:, :, grid_y_hi, :][:, :, :, grid_x_lo]
132 | in_11 = input[:, :, grid_y_hi, :][:, :, :, grid_x_hi]
133 |
134 | in_0 = in_00 + (in_01 - in_00) * grid_dx.view(1, 1, 1, out_size[1])
135 | in_1 = in_10 + (in_11 - in_10) * grid_dx.view(1, 1, 1, out_size[1])
136 | out = in_0 + (in_1 - in_0) * grid_dy.view(1, 1, out_size[0], 1)
137 |
138 | return out
139 |
140 | if method == "slow":
141 | out = resample_manually()
142 | else:
143 | out = resample_using_grid_sample()
144 |
145 | return out
146 |
--------------------------------------------------------------------------------
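
A small comparison sketch for the TensorFlow 1.x-compatible resize above against PyTorch's native bilinear interpolation; the 32x32 -> 299x299 configuration mirrors the setting quoted in the docstring, and the random input is illustrative:

    import torch
    import torch.nn.functional as F
    from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x

    x = torch.rand(1, 3, 32, 32)
    y_tf1x = interpolate_bilinear_2d_like_tensorflow1x(x, size=(299, 299), align_corners=False, method="slow")
    y_torch = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)

    # The two samplers place pixel centers differently, so the outputs do not coincide.
    print((y_tf1x - y_torch).abs().max().item())
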
/torch_fidelity/metric_fid.py:
--------------------------------------------------------------------------------
1 | # Functions fid_features_to_statistics and fid_statistics_to_metric are adapted from
2 | # https://github.com/bioinf-jku/TTUR/blob/master/fid.py commit id d4baae8
3 | # Distributed under Apache License 2.0: https://github.com/bioinf-jku/TTUR/blob/master/LICENSE
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from torch_fidelity.helpers import get_kwarg, vprint
9 | from torch_fidelity.utils import (
10 | get_cacheable_input_name,
11 | cache_lookup_one_recompute_on_miss,
12 | extract_featuresdict_from_input_id_cached,
13 | create_feature_extractor,
14 | resolve_feature_extractor,
15 | resolve_feature_layer_for_metric,
16 | )
17 |
18 | KEY_METRIC_FID = "frechet_inception_distance"
19 |
20 |
21 | def fid_features_to_statistics(features):
22 | assert torch.is_tensor(features) and features.dim() == 2
23 | features = features.numpy()
24 | mu = np.mean(features, axis=0)
25 | sigma = np.cov(features, rowvar=False)
26 | return {
27 | "mu": mu,
28 | "sigma": sigma,
29 | }
30 |
31 |
32 | def fid_statistics_to_metric(stat_1, stat_2, verbose):
33 | mu1, sigma1 = stat_1["mu"], stat_1["sigma"]
34 | mu2, sigma2 = stat_2["mu"], stat_2["sigma"]
35 | assert mu1.ndim == 1 and mu1.shape == mu2.shape and mu1.dtype == mu2.dtype
36 | assert sigma1.ndim == 2 and sigma1.shape == sigma2.shape and sigma1.dtype == sigma2.dtype
37 |
38 | diff = mu1 - mu2
39 | tr_covmean = np.sum(np.sqrt(np.linalg.eigvals(sigma1.dot(sigma2)).astype("complex128")).real)
40 | fid = float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
41 |
42 | out = {KEY_METRIC_FID: fid}
43 |
44 | vprint(verbose, f"Frechet Inception Distance: {out[KEY_METRIC_FID]:.7g}")
45 |
46 | return out
47 |
48 |
49 | def fid_featuresdict_to_statistics(featuresdict, feat_layer_name):
50 | features = featuresdict[feat_layer_name]
51 | statistics = fid_features_to_statistics(features)
52 | return statistics
53 |
54 |
55 | def fid_featuresdict_to_statistics_cached(
56 | featuresdict, cacheable_input_name, feat_extractor, feat_layer_name, **kwargs
57 | ):
58 | def fn_recompute():
59 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name)
60 |
61 | if cacheable_input_name is not None:
62 | feat_extractor_name = feat_extractor.get_name()
63 | cached_name = f"{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}"
64 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs)
65 | else:
66 | stat = fn_recompute()
67 | return stat
68 |
69 |
70 | def fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs):
71 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs)
72 | return fid_featuresdict_to_statistics(featuresdict, feat_layer_name)
73 |
74 |
75 | def fid_input_id_to_statistics_cached(input_id, feat_extractor, feat_layer_name, **kwargs):
76 | def fn_recompute():
77 | return fid_input_id_to_statistics(input_id, feat_extractor, feat_layer_name, **kwargs)
78 |
79 | cacheable_input_name = get_cacheable_input_name(input_id, **kwargs)
80 |
81 | if cacheable_input_name is not None:
82 | feat_extractor_name = feat_extractor.get_name()
83 | cached_name = f"{cacheable_input_name}-{feat_extractor_name}-stat-fid-{feat_layer_name}"
84 | stat = cache_lookup_one_recompute_on_miss(cached_name, fn_recompute, **kwargs)
85 | else:
86 | stat = fn_recompute()
87 | return stat
88 |
89 |
90 | def fid_inputs_to_metric(feat_extractor, **kwargs):
91 | feat_layer_name = resolve_feature_layer_for_metric("fid", **kwargs)
92 | verbose = get_kwarg("verbose", kwargs)
93 |
94 | vprint(verbose, f"Extracting statistics from input 1")
95 | stats_1 = fid_input_id_to_statistics_cached(1, feat_extractor, feat_layer_name, **kwargs)
96 |
97 | vprint(verbose, f"Extracting statistics from input 2")
98 | stats_2 = fid_input_id_to_statistics_cached(2, feat_extractor, feat_layer_name, **kwargs)
99 |
100 | metric = fid_statistics_to_metric(stats_1, stats_2, get_kwarg("verbose", kwargs))
101 | return metric
102 |
103 |
104 | def calculate_fid(**kwargs):
105 | kwargs["fid"] = True
106 | feature_extractor = resolve_feature_extractor(**kwargs)
107 | feat_layer_name = resolve_feature_layer_for_metric("fid", **kwargs)
108 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
109 | metric = fid_inputs_to_metric(feat_extractor, **kwargs)
110 | return metric
111 |
--------------------------------------------------------------------------------
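
The statistics/metric split in torch_fidelity/metric_fid.py above can be exercised directly on raw feature tensors; a sketch with synthetic Gaussian features (dimensionality and mean shift are arbitrary):

    import torch
    from torch_fidelity.metric_fid import fid_features_to_statistics, fid_statistics_to_metric

    feat_real = torch.randn(2000, 64)
    feat_fake = torch.randn(2000, 64) + 0.5   # mean shift inflates the squared-difference term of FID

    stat_real = fid_features_to_statistics(feat_real)
    stat_fake = fid_features_to_statistics(feat_fake)
    print(fid_statistics_to_metric(stat_real, stat_fake, verbose=True))
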
/torch_fidelity/metric_isc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from torch_fidelity.helpers import get_kwarg, vprint
5 | from torch_fidelity.utils import (
6 | extract_featuresdict_from_input_id_cached,
7 | create_feature_extractor,
8 | resolve_feature_extractor,
9 | resolve_feature_layer_for_metric,
10 | )
11 |
12 | KEY_METRIC_ISC_MEAN = "inception_score_mean"
13 | KEY_METRIC_ISC_STD = "inception_score_std"
14 |
15 |
16 | def isc_features_to_metric(feature, splits=10, shuffle=True, rng_seed=2020):
17 | assert torch.is_tensor(feature) and feature.dim() == 2
18 | N, C = feature.shape
19 | if shuffle:
20 | rng = np.random.RandomState(rng_seed)
21 | feature = feature[rng.permutation(N), :]
22 | feature = feature.double()
23 |
24 | p = feature.softmax(dim=1)
25 | log_p = feature.log_softmax(dim=1)
26 |
27 | scores = []
28 | for i in range(splits):
29 | p_chunk = p[(i * N // splits) : ((i + 1) * N // splits), :]
30 | log_p_chunk = log_p[(i * N // splits) : ((i + 1) * N // splits), :]
31 | q_chunk = p_chunk.mean(dim=0, keepdim=True)
32 | kl = p_chunk * (log_p_chunk - q_chunk.log())
33 | kl = kl.sum(dim=1).mean().exp().item()
34 | scores.append(kl)
35 |
36 | return {
37 | KEY_METRIC_ISC_MEAN: float(np.mean(scores)),
38 | KEY_METRIC_ISC_STD: float(np.std(scores)),
39 | }
40 |
41 |
42 | def isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs):
43 | features = featuresdict[feat_layer_name]
44 |
45 | out = isc_features_to_metric(
46 | features,
47 | get_kwarg("isc_splits", kwargs),
48 | get_kwarg("samples_shuffle", kwargs),
49 | get_kwarg("rng_seed", kwargs),
50 | )
51 |
52 | vprint(
53 | get_kwarg("verbose", kwargs), f"Inception Score: {out[KEY_METRIC_ISC_MEAN]:.7g} ± {out[KEY_METRIC_ISC_STD]:.7g}"
54 | )
55 |
56 | return out
57 |
58 |
59 | def isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs):
60 | featuresdict = extract_featuresdict_from_input_id_cached(input_id, feat_extractor, **kwargs)
61 | return isc_featuresdict_to_metric(featuresdict, feat_layer_name, **kwargs)
62 |
63 |
64 | def calculate_isc(input_id, **kwargs):
65 | kwargs["isc"] = True
66 | feature_extractor = resolve_feature_extractor(**kwargs)
67 | feat_layer_name = resolve_feature_layer_for_metric("isc", **kwargs)
68 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
69 | metric = isc_input_id_to_metric(input_id, feat_extractor, feat_layer_name, **kwargs)
70 | return metric
71 |
--------------------------------------------------------------------------------
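
isc_features_to_metric above works on raw classifier logits, so it can be sanity-checked without any feature extractor; random logits yield a low score only slightly above 1, because each split's marginal stays close to the average prediction (the class count below is arbitrary):

    import torch
    from torch_fidelity.metric_isc import isc_features_to_metric

    logits = torch.randn(5000, 10)   # random 10-class logits
    print(isc_features_to_metric(logits, splits=10, shuffle=True, rng_seed=2020))
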
/torch_fidelity/metric_kid.py:
--------------------------------------------------------------------------------
1 | # Functions mmd2 and polynomial_kernel are adapted from
2 | # https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py
3 | # Distributed under BSD 3-Clause: https://github.com/mbinkowski/MMD-GAN/blob/master/LICENSE
4 |
5 | import numpy as np
6 | import torch
7 | from tqdm import tqdm
8 |
9 | from torch_fidelity.helpers import get_kwarg, vassert, vprint
10 | from torch_fidelity.utils import (
11 | create_feature_extractor,
12 | extract_featuresdict_from_input_id_cached,
13 | resolve_feature_extractor,
14 | resolve_feature_layer_for_metric,
15 | )
16 |
17 | KEY_METRIC_KID_MEAN = "kernel_inception_distance_mean"
18 | KEY_METRIC_KID_STD = "kernel_inception_distance_std"
19 |
20 |
21 | def mmd2(K_XX, K_XY, K_YY, unit_diagonal=False, mmd_est="unbiased"):
22 | vassert(mmd_est in ("biased", "unbiased", "u-statistic"), "Invalid value of mmd_est")
23 |
24 | m = K_XX.shape[0]
25 | assert K_XX.shape == (m, m)
26 | assert K_XY.shape == (m, m)
27 | assert K_YY.shape == (m, m)
28 |
29 | # Get the various sums of kernels that we'll use
30 | # Kts drop the diagonal, but we don't need to compute them explicitly
31 | if unit_diagonal:
32 | diag_X = diag_Y = 1
33 | sum_diag_X = sum_diag_Y = m
34 | else:
35 | diag_X = np.diagonal(K_XX)
36 | diag_Y = np.diagonal(K_YY)
37 |
38 | sum_diag_X = diag_X.sum()
39 | sum_diag_Y = diag_Y.sum()
40 |
41 | Kt_XX_sums = K_XX.sum(axis=1) - diag_X
42 | Kt_YY_sums = K_YY.sum(axis=1) - diag_Y
43 | K_XY_sums_0 = K_XY.sum(axis=0)
44 |
45 | Kt_XX_sum = Kt_XX_sums.sum()
46 | Kt_YY_sum = Kt_YY_sums.sum()
47 | K_XY_sum = K_XY_sums_0.sum()
48 |
49 | if mmd_est == "biased":
50 | mmd2 = (Kt_XX_sum + sum_diag_X) / (m * m) \
51 | + (Kt_YY_sum + sum_diag_Y) / (m * m) \
52 | - 2 * K_XY_sum / (m * m) # fmt: skip
53 | else:
54 | mmd2 = (Kt_XX_sum + Kt_YY_sum) / (m * (m - 1))
55 | if mmd_est == "unbiased":
56 | mmd2 -= 2 * K_XY_sum / (m * m)
57 | else:
58 | mmd2 -= 2 * (K_XY_sum - np.trace(K_XY)) / (m * (m - 1))
59 |
60 | return mmd2
61 |
62 |
63 | def kernel_poly(X, Y, **kwargs):
64 | degree = get_kwarg("kid_kernel_poly_degree", kwargs)
65 | gamma = get_kwarg("kid_kernel_poly_gamma", kwargs)
66 | coef0 = get_kwarg("kid_kernel_poly_coef0", kwargs)
67 | if gamma is None:
68 | gamma = 1.0 / X.shape[1]
69 | K = (np.matmul(X, Y.T) * gamma + coef0) ** degree
70 | return K
71 |
72 |
73 | def kernel_rbf(X, Y, **kwargs):
74 | sigma = get_kwarg("kid_kernel_rbf_sigma", kwargs)
75 | vassert(sigma is not None and sigma > 0, "kid_kernel_rbf_sigma must be positive")
76 | XX = np.sum(X**2, axis=1)
77 | YY = np.sum(Y**2, axis=1)
78 | XY = np.dot(X, Y.T)
79 | K = np.exp((2 * XY - np.outer(XX, np.ones(YY.shape[0])) - np.outer(np.ones(XX.shape[0]), YY)) / (2 * sigma**2))
80 | return K
81 |
82 |
83 | def kernel_mmd(features_1, features_2, **kwargs):
84 | kernel = get_kwarg("kid_kernel", kwargs)
85 | vassert(kernel in ("poly", "rbf"), "Invalid KID kernel")
86 | kernel = {
87 | "poly": kernel_poly,
88 | "rbf": kernel_rbf,
89 | }[kernel]
90 | k_11 = kernel(features_1, features_1, **kwargs)
91 | k_22 = kernel(features_2, features_2, **kwargs)
92 | k_12 = kernel(features_1, features_2, **kwargs)
93 | return mmd2(k_11, k_12, k_22)
94 |
95 |
96 | def kid_features_to_metric(features_1, features_2, **kwargs):
97 | assert torch.is_tensor(features_1) and features_1.dim() == 2
98 | assert torch.is_tensor(features_2) and features_2.dim() == 2
99 | assert features_1.shape[1] == features_2.shape[1]
100 |
101 | kid_subsets = get_kwarg("kid_subsets", kwargs)
102 | kid_subset_size = get_kwarg("kid_subset_size", kwargs)
103 | verbose = get_kwarg("verbose", kwargs)
104 |
105 | n_samples_1, n_samples_2 = len(features_1), len(features_2)
106 | vassert(
107 | n_samples_1 >= kid_subset_size and n_samples_2 >= kid_subset_size,
108 |         f"KID subset size {kid_subset_size} cannot be larger than the number of samples (input_1: {n_samples_1}, "
109 |         f'input_2: {n_samples_2}). Consider lowering it via the "kid_subset_size" kwarg or the "--kid-subset-size" '
110 |         f"command line key to proceed.",
111 | )
112 |
113 | features_1 = features_1.cpu().numpy()
114 | features_2 = features_2.cpu().numpy()
115 |
116 | mmds = np.zeros(kid_subsets)
117 | rng = np.random.RandomState(get_kwarg("rng_seed", kwargs))
118 |
119 | for i in tqdm(
120 | range(kid_subsets), disable=not verbose, leave=False, unit="subsets", desc="Kernel Inception Distance"
121 | ):
122 | f1 = features_1[rng.choice(n_samples_1, kid_subset_size, replace=False)]
123 | f2 = features_2[rng.choice(n_samples_2, kid_subset_size, replace=False)]
124 | o = kernel_mmd(f1, f2, **kwargs)
125 | mmds[i] = o
126 |
127 | out = {
128 | KEY_METRIC_KID_MEAN: float(np.mean(mmds)),
129 | KEY_METRIC_KID_STD: float(np.std(mmds)),
130 | }
131 |
132 | vprint(verbose, f"Kernel Inception Distance: {out[KEY_METRIC_KID_MEAN]:.7g} ± {out[KEY_METRIC_KID_STD]:.7g}")
133 |
134 | return out
135 |
136 |
137 | def kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs):
138 | features_1 = featuresdict_1[feat_layer_name]
139 | features_2 = featuresdict_2[feat_layer_name]
140 | metric = kid_features_to_metric(features_1, features_2, **kwargs)
141 |     if metric[KEY_METRIC_KID_MEAN] < 0:
142 |         vprint(get_kwarg("verbose", kwargs), "KID values slightly less than 0 are valid and indicate that the two distributions are very similar")
143 | return metric
144 |
145 |
146 | def calculate_kid(**kwargs):
147 | kwargs["kid"] = True
148 | feature_extractor = resolve_feature_extractor(**kwargs)
149 | feat_layer_name = resolve_feature_layer_for_metric("kid", **kwargs)
150 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
151 | featuresdict_1 = extract_featuresdict_from_input_id_cached(1, feat_extractor, **kwargs)
152 | featuresdict_2 = extract_featuresdict_from_input_id_cached(2, feat_extractor, **kwargs)
153 | metric = kid_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs)
154 | return metric
155 |
--------------------------------------------------------------------------------
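
A sketch driving kid_features_to_metric above on synthetic features; the subset kwargs override the package defaults only to keep the toy computation small, and the call relies on defaults.py providing the polynomial-kernel settings (kid_kernel and the kid_kernel_poly_* keys):

    import torch
    from torch_fidelity.metric_kid import kid_features_to_metric

    feat_1 = torch.randn(500, 64)
    feat_2 = torch.randn(500, 64)

    # Identically distributed inputs: the unbiased MMD^2 estimate should hover around zero.
    print(kid_features_to_metric(feat_1, feat_2, kid_subsets=10, kid_subset_size=100, verbose=False))
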
/torch_fidelity/metric_ppl.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 |
5 | from torch_fidelity.generative_model_base import GenerativeModelBase
6 | from torch_fidelity.helpers import get_kwarg, vassert, vprint
7 | from torch_fidelity.utils import (
8 | sample_random,
9 | batch_interp,
10 | create_sample_similarity,
11 | prepare_input_descriptor_from_input_id,
12 | prepare_input_from_descriptor,
13 | )
14 |
15 | KEY_METRIC_PPL_RAW = "perceptual_path_length_raw"
16 | KEY_METRIC_PPL_MEAN = "perceptual_path_length_mean"
17 | KEY_METRIC_PPL_STD = "perceptual_path_length_std"
18 |
19 |
20 | def calculate_ppl(input_id, **kwargs):
21 | """
22 | Inspired by https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py
23 | """
24 | kwargs["ppl"] = True
25 | batch_size = get_kwarg("batch_size", kwargs)
26 | cuda = get_kwarg("cuda", kwargs)
27 | verbose = get_kwarg("verbose", kwargs)
28 | epsilon = get_kwarg("ppl_epsilon", kwargs)
29 | interp = get_kwarg("ppl_z_interp_mode", kwargs)
30 | reduction = get_kwarg("ppl_reduction", kwargs)
31 | similarity_name = get_kwarg("ppl_sample_similarity", kwargs)
32 | sample_similarity_resize = get_kwarg("ppl_sample_similarity_resize", kwargs)
33 | sample_similarity_dtype = get_kwarg("ppl_sample_similarity_dtype", kwargs)
34 | discard_percentile_lower = get_kwarg("ppl_discard_percentile_lower", kwargs)
35 | discard_percentile_higher = get_kwarg("ppl_discard_percentile_higher", kwargs)
36 |
37 | input_desc = prepare_input_descriptor_from_input_id(input_id, **kwargs)
38 | model = prepare_input_from_descriptor(input_desc, **kwargs)
39 | vassert(
40 | isinstance(model, GenerativeModelBase),
41 | "Input needs to be an instance of GenerativeModelBase, which can be either passed programmatically by wrapping "
42 | "a model with GenerativeModelModuleWrapper, or via command line by specifying a path to a ONNX or PTH (JIT) "
43 | "model and a set of input1_model_* arguments",
44 | )
45 |
46 | if cuda:
47 | model.cuda()
48 |
49 | input_model_num_samples = input_desc["input_model_num_samples"]
50 | input_model_num_classes = model.num_classes
51 | input_model_z_size = model.z_size
52 | input_model_z_type = model.z_type
53 |
54 | vassert(input_model_num_classes >= 0, "Model can be unconditional (0 classes) or conditional (positive)")
55 | vassert(
56 | type(input_model_z_size) is int and input_model_z_size > 0,
57 | 'Dimensionality of generator noise not specified ("input1_model_z_size" argument)',
58 | )
59 | vassert(type(epsilon) is float and epsilon > 0, "Epsilon must be a small positive floating point number")
60 | vassert(type(input_model_num_samples) is int and input_model_num_samples > 0, "Number of samples must be positive")
61 | vassert(reduction in ("none", "mean"), "Reduction must be one of [none, mean]")
62 | vassert(discard_percentile_lower is None or 0 < discard_percentile_lower < 100, "Invalid percentile")
63 | vassert(discard_percentile_higher is None or 0 < discard_percentile_higher < 100, "Invalid percentile")
64 | if discard_percentile_lower is not None and discard_percentile_higher is not None:
65 | vassert(0 < discard_percentile_lower < discard_percentile_higher < 100, "Invalid percentiles")
66 |
67 | sample_similarity = create_sample_similarity(
68 | similarity_name,
69 | sample_similarity_resize=sample_similarity_resize,
70 | sample_similarity_dtype=sample_similarity_dtype,
71 | **kwargs,
72 | )
73 |
74 | is_cond = input_desc["input_model_num_classes"] > 0
75 |
76 | rng = np.random.RandomState(get_kwarg("rng_seed", kwargs))
77 |
78 | lat_e0 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type)
79 | lat_e1 = sample_random(rng, (input_model_num_samples, input_model_z_size), input_model_z_type)
80 | lat_e1 = batch_interp(lat_e0, lat_e1, epsilon, interp)
81 |
82 | labels = None
83 | if is_cond:
84 | labels = torch.from_numpy(rng.randint(0, input_model_num_classes, (input_model_num_samples,)))
85 |
86 | distances = []
87 |
88 | with tqdm(
89 | disable=not verbose, leave=False, unit="samples", total=input_model_num_samples, desc="Perceptual Path Length"
90 | ) as t, torch.no_grad():
91 | for begin_id in range(0, input_model_num_samples, batch_size):
92 | end_id = min(begin_id + batch_size, input_model_num_samples)
93 | batch_sz = end_id - begin_id
94 |
95 | batch_lat_e0 = lat_e0[begin_id:end_id]
96 | batch_lat_e1 = lat_e1[begin_id:end_id]
97 | if is_cond:
98 | batch_labels = labels[begin_id:end_id]
99 |
100 | if cuda:
101 | batch_lat_e0 = batch_lat_e0.cuda(non_blocking=True)
102 | batch_lat_e1 = batch_lat_e1.cuda(non_blocking=True)
103 | if is_cond:
104 | batch_labels = batch_labels.cuda(non_blocking=True)
105 |
106 | if is_cond:
107 | rgb_e01 = model.forward(
108 | torch.cat((batch_lat_e0, batch_lat_e1), dim=0),
109 | torch.cat((batch_labels, batch_labels), dim=0),
110 | )
111 | else:
112 | rgb_e01 = model.forward(torch.cat((batch_lat_e0, batch_lat_e1), dim=0))
113 | rgb_e0, rgb_e1 = rgb_e01.chunk(2)
114 |
115 | sim = sample_similarity(rgb_e0, rgb_e1)
116 | dist_lat_e01 = sim / (epsilon**2)
117 | distances.append(dist_lat_e01.cpu().numpy())
118 |
119 | t.update(batch_sz)
120 |
121 | distances = np.concatenate(distances, axis=0)
122 |
123 | cond, lo, hi = None, None, None
124 | if discard_percentile_lower is not None:
125 | lo = np.percentile(distances, discard_percentile_lower, interpolation="lower")
126 | cond = lo <= distances
127 | if discard_percentile_higher is not None:
128 | hi = np.percentile(distances, discard_percentile_higher, interpolation="higher")
129 | cond = np.logical_and(cond, distances <= hi)
130 | if cond is not None:
131 | distances = np.extract(cond, distances)
132 |
133 | out = {
134 | KEY_METRIC_PPL_MEAN: float(np.mean(distances)),
135 | KEY_METRIC_PPL_STD: float(np.std(distances)),
136 | }
137 | if reduction == "none":
138 | out[KEY_METRIC_PPL_RAW] = distances
139 |
140 | vprint(verbose, f"Perceptual Path Length: {out[KEY_METRIC_PPL_MEAN]:.7g} ± {out[KEY_METRIC_PPL_STD]:.7g}")
141 |
142 | return out
143 |
--------------------------------------------------------------------------------
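
The core quantity behind calculate_ppl above is d(G(z0), G(z1)) / eps^2, with z1 an epsilon-step away from z0. A self-contained toy sketch of just that estimator, with a made-up linear-tanh "generator" and plain squared error standing in for the LPIPS similarity (purely illustrative; not the model or similarity measure used in this repository):

    import torch

    torch.manual_seed(0)
    W = torch.randn(8, 16)                 # toy "generator": G(z) = tanh(z @ W)
    def G(z):
        return torch.tanh(z @ W)

    eps = 1e-4
    z0 = torch.randn(4096, 8)
    step = torch.nn.functional.normalize(torch.randn(4096, 8), dim=1) * eps
    z1 = z0 + step

    # Squared distance between endpoint outputs, rescaled by eps^2 (LPIPS plays this role in the real metric).
    d = ((G(z0) - G(z1)) ** 2).mean(dim=1) / eps**2
    print(float(d.mean()), float(d.std()))
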
/torch_fidelity/metric_prc.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch_fidelity.helpers import get_kwarg, vprint
4 | from torch_fidelity.utils import (
5 | create_feature_extractor,
6 | extract_featuresdict_from_input_id_cached,
7 | resolve_feature_extractor,
8 | resolve_feature_layer_for_metric,
9 | )
10 |
11 | KEY_METRIC_PRECISION = "precision"
12 | KEY_METRIC_RECALL = "recall"
13 | KEY_METRIC_F_SCORE = "f_score"
14 |
15 |
16 | def calc_cdist_part(features_1, features_2, batch_size=10000):
17 | dists = []
18 | for feat2_batch in features_2.split(batch_size):
19 | dists.append(torch.cdist(features_1, feat2_batch).cpu())
20 | return torch.cat(dists, dim=1)
21 |
22 |
23 | def calculate_precision_recall_part(features_1, features_2, neighborhood=3, batch_size=10000):
24 | # Precision
25 | dist_nn_1 = []
26 | for feat_1_batch in features_1.split(batch_size):
27 | dist_nn_1.append(calc_cdist_part(feat_1_batch, features_1, batch_size).kthvalue(neighborhood + 1).values)
28 | dist_nn_1 = torch.cat(dist_nn_1)
29 | precision = []
30 | for feat_2_batch in features_2.split(batch_size):
31 | dist_2_1_batch = calc_cdist_part(feat_2_batch, features_1, batch_size)
32 | precision.append((dist_2_1_batch <= dist_nn_1).any(dim=1).float())
33 | precision = torch.cat(precision).mean().item()
34 | # Recall
35 | dist_nn_2 = []
36 | for feat_2_batch in features_2.split(batch_size):
37 | dist_nn_2.append(calc_cdist_part(feat_2_batch, features_2, batch_size).kthvalue(neighborhood + 1).values)
38 | dist_nn_2 = torch.cat(dist_nn_2)
39 | recall = []
40 | for feat_1_batch in features_1.split(batch_size):
41 | dist_1_2_batch = calc_cdist_part(feat_1_batch, features_2, batch_size)
42 | recall.append((dist_1_2_batch <= dist_nn_2).any(dim=1).float())
43 | recall = torch.cat(recall).mean().item()
44 | return precision, recall
45 |
46 |
47 | def calc_cdist_full(features_1, features_2, batch_size=10000):
48 | dists = []
49 | for feat1_batch in features_1.split(batch_size):
50 | dists_batch = []
51 | for feat2_batch in features_2.split(batch_size):
52 | dists_batch.append(torch.cdist(feat1_batch, feat2_batch).cpu())
53 | dists.append(torch.cat(dists_batch, dim=1))
54 | return torch.cat(dists, dim=0)
55 |
56 |
57 | def calculate_precision_recall_full(features_1, features_2, neighborhood=3, batch_size=10000):
58 | dist_nn_1 = calc_cdist_full(features_1, features_1, batch_size).kthvalue(neighborhood + 1).values
59 | dist_nn_2 = calc_cdist_full(features_2, features_2, batch_size).kthvalue(neighborhood + 1).values
60 | dist_2_1 = calc_cdist_full(features_2, features_1, batch_size)
61 | dist_1_2 = dist_2_1.T
62 | # Precision
63 | precision = (dist_2_1 <= dist_nn_1).any(dim=1).float().mean().item()
64 | # Recall
65 | recall = (dist_1_2 <= dist_nn_2).any(dim=1).float().mean().item()
66 | return precision, recall
67 |
68 |
69 | def prc_features_to_metric(features_1, features_2, **kwargs):
70 |     # Convention: features_1 is REAL, features_2 is GENERATED. This is important for the notion of precision/recall only.
71 | assert torch.is_tensor(features_1) and features_1.dim() == 2
72 | assert torch.is_tensor(features_2) and features_2.dim() == 2
73 | assert features_1.shape[1] == features_2.shape[1]
74 |
75 | neighborhood = get_kwarg("prc_neighborhood", kwargs)
76 | batch_size = get_kwarg("prc_batch_size", kwargs)
77 | save_cpu_ram = get_kwarg("save_cpu_ram", kwargs)
78 | verbose = get_kwarg("verbose", kwargs)
79 |
80 | calculate_precision_recall_fn = calculate_precision_recall_part if save_cpu_ram else calculate_precision_recall_full
81 | precision, recall = calculate_precision_recall_fn(features_1, features_2, neighborhood, batch_size)
82 | f_score = 2 * precision * recall / max(1e-5, precision + recall)
83 |
84 | out = {
85 | KEY_METRIC_PRECISION: precision,
86 | KEY_METRIC_RECALL: recall,
87 | KEY_METRIC_F_SCORE: f_score,
88 | }
89 |
90 | vprint(verbose, f"Precision: {out[KEY_METRIC_PRECISION]:.7g}")
91 | vprint(verbose, f"Recall: {out[KEY_METRIC_RECALL]:.7g}")
92 | vprint(verbose, f"F-score: {out[KEY_METRIC_F_SCORE]:.7g}")
93 |
94 | return out
95 |
96 |
97 | def prc_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs):
98 | features_1 = featuresdict_1[feat_layer_name]
99 | features_2 = featuresdict_2[feat_layer_name]
100 | metric = prc_features_to_metric(features_1, features_2, **kwargs)
101 | return metric
102 |
103 |
104 | def calculate_prc(**kwargs):
105 | kwargs["prc"] = True
106 | feature_extractor = resolve_feature_extractor(**kwargs)
107 | feat_layer_name = resolve_feature_layer_for_metric("prc", **kwargs)
108 | feat_extractor = create_feature_extractor(feature_extractor, [feat_layer_name], **kwargs)
109 | featuresdict_1 = extract_featuresdict_from_input_id_cached(1, feat_extractor, **kwargs)
110 | featuresdict_2 = extract_featuresdict_from_input_id_cached(2, feat_extractor, **kwargs)
111 | metric = prc_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs)
112 | return metric
113 |
--------------------------------------------------------------------------------
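
The precision/recall estimators above accept plain 2-D feature tensors, so they can be probed with synthetic data; the feature dimensionality and sample counts below are arbitrary:

    import torch
    from torch_fidelity.metric_prc import calculate_precision_recall_full, calculate_precision_recall_part

    real = torch.randn(1000, 32)
    fake = torch.randn(1000, 32) * 1.5     # broader "generated" distribution: precision tends to drop, recall stays high

    p_full, r_full = calculate_precision_recall_full(real, fake, neighborhood=3, batch_size=10000)
    p_part, r_part = calculate_precision_recall_part(real, fake, neighborhood=3, batch_size=10000)
    print(p_full, r_full)
    print(p_part, r_part)                  # the memory-friendly variant computes the same quantities
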
/torch_fidelity/noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def batch_normalize_last_dim(v, eps=1e-7):
5 | return v / (v**2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps)
6 |
7 |
8 | def random_normal(rng, shape):
9 | return torch.from_numpy(rng.randn(*shape)).float()
10 |
11 |
12 | def random_unit(rng, shape):
13 | return batch_normalize_last_dim(torch.from_numpy(rng.rand(*shape)).float())
14 |
15 |
16 | def random_uniform_0_1(rng, shape):
17 | return torch.from_numpy(rng.rand(*shape)).float()
18 |
19 |
20 | def batch_lerp(a, b, t):
21 | return a + (b - a) * t
22 |
23 |
24 | def batch_slerp_any(a, b, t, eps=1e-7):
25 | assert torch.is_tensor(a) and torch.is_tensor(b) and a.dim() >= 2 and a.shape == b.shape
26 | ndims, N = a.dim() - 1, a.shape[-1]
27 | a_1 = batch_normalize_last_dim(a, eps)
28 | b_1 = batch_normalize_last_dim(b, eps)
29 | d = (a_1 * b_1).sum(dim=-1, keepdim=True)
30 | mask_zero = (a_1.norm(dim=-1, keepdim=True) < eps) | (b_1.norm(dim=-1, keepdim=True) < eps)
31 | mask_collinear = (d > 1 - eps) | (d < -1 + eps)
32 | mask_lerp = (mask_zero | mask_collinear).repeat([1 for _ in range(ndims)] + [N])
33 | omega = d.acos()
34 | denom = omega.sin().clamp_min(eps)
35 | coef_a = ((1 - t) * omega).sin() / denom
36 | coef_b = (t * omega).sin() / denom
37 | out = coef_a * a + coef_b * b
38 | out[mask_lerp] = batch_lerp(a, b, t)[mask_lerp]
39 | return out
40 |
41 |
42 | def batch_slerp_unit(a, b, t, eps=1e-7):
43 | out = batch_slerp_any(a, b, t, eps)
44 | out = batch_normalize_last_dim(out, eps)
45 | return out
46 |
--------------------------------------------------------------------------------
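
A short sketch of the noise and interpolation helpers above; the seed and shapes are arbitrary:

    import numpy as np
    from torch_fidelity.noise import random_normal, random_unit, batch_lerp, batch_slerp_any, batch_slerp_unit

    rng = np.random.RandomState(2020)
    a = random_normal(rng, (4, 128))
    b = random_normal(rng, (4, 128))

    mid_lerp = batch_lerp(a, b, 0.5)        # straight line between endpoints
    mid_slerp = batch_slerp_any(a, b, 0.5)  # great-circle interpolation, falls back to lerp for degenerate pairs
    unit = random_unit(rng, (4, 128))       # row-normalized samples
    print(batch_slerp_unit(unit, random_unit(rng, (4, 128)), 0.25).norm(dim=-1))  # stays approximately unit-norm
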
/torch_fidelity/registry.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from torch_fidelity.datasets import TransformPILtoRGBTensor, Cifar10_RGB, Cifar100_RGB, STL10_RGB
4 | from torch_fidelity.feature_extractor_base import FeatureExtractorBase
5 | from torch_fidelity.feature_extractor_clip import FeatureExtractorCLIP
6 | from torch_fidelity.feature_extractor_dinov2 import FeatureExtractorDinoV2
7 | from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3
8 | from torch_fidelity.feature_extractor_vgg16 import FeatureExtractorVGG16
9 | from torch_fidelity.helpers import vassert
10 | from torch_fidelity.noise import (
11 | random_normal,
12 | random_unit,
13 | random_uniform_0_1,
14 | batch_lerp,
15 | batch_slerp_any,
16 | batch_slerp_unit,
17 | )
18 | from torch_fidelity.sample_similarity_base import SampleSimilarityBase
19 | from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS
20 |
21 | DATASETS_REGISTRY = dict()
22 | FEATURE_EXTRACTORS_REGISTRY = dict()
23 | SAMPLE_SIMILARITY_REGISTRY = dict()
24 | NOISE_SOURCE_REGISTRY = dict()
25 | INTERPOLATION_REGISTRY = dict()
26 |
27 |
28 | def register_dataset(name, fn_create):
29 | """
30 | Registers a new input source.
31 |
32 | Args:
33 |
34 | name (str): Unique name of the input source.
35 |
36 | fn_create (callable): A constructor of a :class:`~torch:torch.utils.data.Dataset` instance. Callable arguments:
37 |
38 | - `root` (str): Location where the dataset files may be downloaded.
39 | - `download` (bool): Whether to perform downloading or rely on the cached version.
40 | """
41 | vassert(type(name) is str, "Dataset must be given a name")
42 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces")
43 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)")
44 | vassert(name not in DATASETS_REGISTRY, f'Dataset "{name}" is already registered')
45 | vassert(
46 | callable(fn_create),
47 |         "Dataset must be provided as a callable (function, lambda) with 2 arguments: root (str), download (bool)",
48 | )
49 | DATASETS_REGISTRY[name] = fn_create
50 |
51 |
52 | def register_feature_extractor(name, cls):
53 | """
54 | Registers a new feature extractor.
55 |
56 | Args:
57 |
58 | name (str): Unique name of the feature extractor.
59 |
60 |         cls (FeatureExtractorBase): A subclass of :class:`FeatureExtractorBase`, implementing a new feature extractor.
61 | """
62 | vassert(type(name) is str, "Feature extractor must be given a name")
63 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces")
64 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)")
65 | vassert(name not in FEATURE_EXTRACTORS_REGISTRY, f'Feature extractor "{name}" is already registered')
66 | vassert(
67 | issubclass(cls, FeatureExtractorBase), "Feature extractor class must be subclassed from FeatureExtractorBase"
68 | )
69 | FEATURE_EXTRACTORS_REGISTRY[name] = cls
70 |
71 |
72 | def register_sample_similarity(name, cls):
73 | """
74 | Registers a new sample similarity measure.
75 |
76 | Args:
77 |
78 | name (str): Unique name of the sample similarity measure.
79 |
80 |         cls (SampleSimilarityBase): A subclass of :class:`SampleSimilarityBase`, implementing a new sample similarity
81 |             measure.
82 | """
83 | vassert(type(name) is str, "Sample similarity must be given a name")
84 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces")
85 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)")
86 | vassert(name not in SAMPLE_SIMILARITY_REGISTRY, f'Sample similarity "{name}" is already registered')
87 | vassert(
88 | issubclass(cls, SampleSimilarityBase), "Sample similarity class must be subclassed from SampleSimilarityBase"
89 | )
90 | SAMPLE_SIMILARITY_REGISTRY[name] = cls
91 |
92 |
93 | def register_noise_source(name, fn_generate):
94 | """
95 | Registers a new noise source, which can generate samples to be used as inputs to generative models.
96 |
97 | Args:
98 |
99 | name (str): Unique name of the noise source.
100 |
101 | fn_generate (callable): Generator of a random samples of specified type and shape. Callable arguments:
102 |
103 | - `rng` (numpy.random.RandomState): random number generator state, initialized with \
104 | :paramref:`~calculate_metrics.seed`.
105 | - `shape` (torch.Size): shape of the tensor of random samples.
106 | """
107 | vassert(type(name) is str, "Noise source must be given a name")
108 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces")
109 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)")
110 | vassert(name not in NOISE_SOURCE_REGISTRY, f'Noise source "{name}" is already registered')
111 | vassert(
112 | callable(fn_generate),
113 | "Noise source must be provided as a callable (function, lambda) with 2 arguments: rng, shape",
114 | )
115 | NOISE_SOURCE_REGISTRY[name] = fn_generate
116 |
117 |
118 | def register_interpolation(name, fn_interpolate):
119 | """
120 | Registers a new sample interpolation method.
121 |
122 | Args:
123 |
124 | name (str): Unique name of the interpolation method.
125 |
126 | fn_interpolate (callable): Sample interpolation function. Callable arguments:
127 |
128 | - `a` (torch.Tensor): batch of the first endpoint samples.
129 | - `b` (torch.Tensor): batch of the second endpoint samples.
130 | - `t` (float): interpolation coefficient in the range [0,1].
131 | """
132 | vassert(type(name) is str, "Interpolation must be given a name")
133 | vassert(name.strip() == name, "Name must not have leading or trailing whitespaces")
134 | vassert(os.path.sep not in name, "Name must not contain path delimiters (slash/backslash)")
135 | vassert(name not in INTERPOLATION_REGISTRY, f'Interpolation "{name}" is already registered')
136 | vassert(
137 | callable(fn_interpolate),
138 | "Interpolation must be provided as a callable (function, lambda) with 3 arguments: a, b, t",
139 | )
140 | INTERPOLATION_REGISTRY[name] = fn_interpolate
141 |
142 |
143 | register_dataset(
144 | "cifar10-train",
145 | lambda root, download: Cifar10_RGB(root, train=True, transform=TransformPILtoRGBTensor(), download=download),
146 | )
147 | register_dataset(
148 | "cifar10-val",
149 | lambda root, download: Cifar10_RGB(root, train=False, transform=TransformPILtoRGBTensor(), download=download),
150 | )
151 | register_dataset(
152 | "cifar100-train",
153 | lambda root, download: Cifar100_RGB(root, train=True, transform=TransformPILtoRGBTensor(), download=download),
154 | )
155 | register_dataset(
156 | "cifar100-val",
157 | lambda root, download: Cifar100_RGB(root, train=False, transform=TransformPILtoRGBTensor(), download=download),
158 | )
159 | register_dataset(
160 | "stl10-train",
161 | lambda root, download: STL10_RGB(root, split="train", transform=TransformPILtoRGBTensor(), download=download),
162 | )
163 | register_dataset(
164 | "stl10-test",
165 | lambda root, download: STL10_RGB(root, split="test", transform=TransformPILtoRGBTensor(), download=download),
166 | )
167 | register_dataset(
168 | "stl10-unlabeled",
169 | lambda root, download: STL10_RGB(root, split="unlabeled", transform=TransformPILtoRGBTensor(), download=download),
170 | )
171 |
172 | register_feature_extractor("inception-v3-compat", FeatureExtractorInceptionV3)
173 |
174 | register_feature_extractor("vgg16", FeatureExtractorVGG16)
175 |
176 | register_feature_extractor("clip-rn50", FeatureExtractorCLIP)
177 | register_feature_extractor("clip-rn101", FeatureExtractorCLIP)
178 | register_feature_extractor("clip-rn50x4", FeatureExtractorCLIP)
179 | register_feature_extractor("clip-rn50x16", FeatureExtractorCLIP)
180 | register_feature_extractor("clip-rn50x64", FeatureExtractorCLIP)
181 | register_feature_extractor("clip-vit-b-32", FeatureExtractorCLIP)
182 | register_feature_extractor("clip-vit-b-16", FeatureExtractorCLIP)
183 | register_feature_extractor("clip-vit-l-14", FeatureExtractorCLIP)
184 | register_feature_extractor("clip-vit-l-14-336px", FeatureExtractorCLIP)
185 |
186 | register_feature_extractor("dinov2-vit-s-14", FeatureExtractorDinoV2)
187 | register_feature_extractor("dinov2-vit-b-14", FeatureExtractorDinoV2)
188 | register_feature_extractor("dinov2-vit-l-14", FeatureExtractorDinoV2)
189 | register_feature_extractor("dinov2-vit-g-14", FeatureExtractorDinoV2)
190 |
191 | register_sample_similarity("lpips-vgg16", SampleSimilarityLPIPS)
192 |
193 | register_noise_source("normal", random_normal)
194 | register_noise_source("unit", random_unit)
195 | register_noise_source("uniform_0_1", random_uniform_0_1)
196 |
197 | register_interpolation("lerp", batch_lerp)
198 | register_interpolation("slerp_any", batch_slerp_any)
199 | register_interpolation("slerp_unit", batch_slerp_unit)
200 |
--------------------------------------------------------------------------------
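
A sketch of extending the registries above with a custom noise source and interpolation; the names "normal-0.5" and "hold-a" are made up for this example and are not entries shipped with the package:

    import torch
    from torch_fidelity.registry import register_noise_source, register_interpolation, NOISE_SOURCE_REGISTRY

    # A scaled Gaussian noise source, following the (rng, shape) callable contract documented above.
    register_noise_source("normal-0.5", lambda rng, shape: torch.from_numpy(rng.randn(*shape)).float() * 0.5)

    # A degenerate "interpolation" that always returns the first endpoint, just to show the (a, b, t) contract.
    register_interpolation("hold-a", lambda a, b, t: a)

    print("normal-0.5" in NOISE_SOURCE_REGISTRY)   # True
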
/torch_fidelity/sample_similarity_base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from torch_fidelity.helpers import vassert
4 |
5 |
6 | class SampleSimilarityBase(nn.Module):
7 | def __init__(self, name):
8 | """
9 | Base class for samples similarity measures that can be used in :func:`calculate_metrics`.
10 |
11 | Args:
12 |
13 | name (str): Unique name of the subclassed sample similarity measure, must be the same as used in
14 | :func:`register_sample_similarity`.
15 | """
16 | super(SampleSimilarityBase, self).__init__()
17 | vassert(type(name) is str, "Sample similarity name must be a string")
18 | self.name = name
19 |
20 | def get_name(self):
21 | return self.name
22 |
23 | def forward(self, *args):
24 | """
25 | Returns the value of sample similarity between the inputs.
26 | """
27 | raise NotImplementedError
28 |
--------------------------------------------------------------------------------
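
A minimal subclass sketch showing the contract the base class above expects (a name passed to the constructor, and a forward that maps two image batches to one similarity value per pair); the mean-squared-error measure is illustrative only:

    import torch
    from torch_fidelity.sample_similarity_base import SampleSimilarityBase

    class SampleSimilarityMSE(SampleSimilarityBase):
        def __init__(self, name, **kwargs):
            super().__init__(name)

        def forward(self, in0, in1):
            # One value per pair of samples, mirroring how LPIPS-style measures are consumed.
            return ((in0.float() - in1.float()) ** 2).mean(dim=(1, 2, 3))

    measure = SampleSimilarityMSE("mse-toy")
    x = torch.rand(2, 3, 32, 32)
    print(measure(x, x + 0.1))
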
/torch_fidelity/sample_similarity_lpips.py:
--------------------------------------------------------------------------------
1 | # Adaptation of the following sources:
2 | # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/pretrained_networks.py
3 | # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py
4 | # Distributed under BSD 2-Clause: https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE
5 | import sys
6 | from contextlib import redirect_stdout
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.hub import load_state_dict_from_url
11 |
12 | from torch_fidelity.helpers import vassert, text_to_dtype
13 | from torch_fidelity.sample_similarity_base import SampleSimilarityBase
14 | from torch_fidelity.utils_torchvision import torchvision_load_pretrained_vgg16
15 |
16 | # VGG16 LPIPS original weights re-uploaded from the following location:
17 | # https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/weights/v0.1/vgg.pth
18 | # Distributed under BSD 2-Clause: https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE
19 | URL_VGG16_LPIPS = "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-vgg16-lpips.pth"
20 |
21 |
22 | class VGG16features(torch.nn.Module):
23 | def __init__(self):
24 | super().__init__()
25 | vgg_pretrained_features = torchvision_load_pretrained_vgg16().features
26 | self.slice1 = torch.nn.Sequential()
27 | self.slice2 = torch.nn.Sequential()
28 | self.slice3 = torch.nn.Sequential()
29 | self.slice4 = torch.nn.Sequential()
30 | self.slice5 = torch.nn.Sequential()
31 | self.N_slices = 5
32 | for x in range(4):
33 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
34 | for x in range(4, 9):
35 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
36 | for x in range(9, 16):
37 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
38 | for x in range(16, 23):
39 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
40 | for x in range(23, 30):
41 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
42 | self.eval()
43 | for param in self.parameters():
44 | param.requires_grad = False
45 |
46 | def forward(self, X):
47 | h = self.slice1(X)
48 | h_relu1_2 = h
49 | h = self.slice2(h)
50 | h_relu2_2 = h
51 | h = self.slice3(h)
52 | h_relu3_3 = h
53 | h = self.slice4(h)
54 | h_relu4_3 = h
55 | h = self.slice5(h)
56 | h_relu5_3 = h
57 | return h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3
58 |
59 |
60 | def spatial_average(in_tensor):
61 | return in_tensor.mean([2, 3]).squeeze(1)
62 |
63 |
64 | def normalize_tensor(in_features, eps=1e-10):
65 | norm_factor = torch.sqrt(torch.sum(in_features**2, dim=1, keepdim=True))
66 | return in_features / (norm_factor + eps)
67 |
68 |
69 | class NetLinLayer(nn.Module):
70 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
71 | super(NetLinLayer, self).__init__()
72 | layers = (
73 | [
74 | nn.Dropout(),
75 | ]
76 | if use_dropout
77 | else []
78 | )
79 | layers += [
80 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
81 | ]
82 | self.model = nn.Sequential(*layers)
83 |
84 |
85 | class SampleSimilarityLPIPS(SampleSimilarityBase):
86 | def __init__(self, name, sample_similarity_resize=None, sample_similarity_dtype=None, **kwargs):
87 | """
88 | LPIPS sample similarity measure for 2D RGB 24bit images.
89 |
90 | Args:
91 |
92 | name (str): Unique name of the sample similarity measure, must be the same as used in
93 | :func:`register_sample_similarity`.
94 |
95 | sample_similarity_resize (int or None): Resizes inputs to this size if set, keeps as is if `None`.
96 |
97 | sample_similarity_dtype (str): Coerces tensor dtype to one of the following: 'uint8', 'float32'.
98 | This is useful when the inputs are generated by a generative model, to ensure the proper data range and
99 | quantization.
100 | """
101 | super(SampleSimilarityLPIPS, self).__init__(name)
102 | self.sample_similarity_resize = sample_similarity_resize
103 | self.sample_similarity_dtype = sample_similarity_dtype
104 | self.chns = [64, 128, 256, 512, 512]
105 | self.L = len(self.chns)
106 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=True)
107 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=True)
108 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=True)
109 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=True)
110 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=True)
111 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
112 | with redirect_stdout(sys.stderr):
113 | state_dict = load_state_dict_from_url(URL_VGG16_LPIPS, map_location="cpu", progress=True)
114 | self.load_state_dict(state_dict)
115 | self.net = VGG16features()
116 | self.eval()
117 | for param in self.parameters():
118 | param.requires_grad = False
119 |
120 | @staticmethod
121 | def normalize(x):
122 | # torchvision values in range [0,1] mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]
123 | mean_rescaled = (1 + torch.tensor([-0.030, -0.088, -0.188], device=x.device)[None, :, None, None]) * 255 / 2
124 | inv_std_rescaled = 2 / (torch.tensor([0.458, 0.448, 0.450], device=x.device)[None, :, None, None] * 255)
125 | x = (x.float() - mean_rescaled) * inv_std_rescaled
126 | return x
127 |
128 | @staticmethod
129 | def resize(x, size):
130 | if x.shape[-1] > size and x.shape[-2] > size:
131 | x = torch.nn.functional.interpolate(x, (size, size), mode="area")
132 | else:
133 | x = torch.nn.functional.interpolate(x, (size, size), mode="bilinear", align_corners=False)
134 | return x
135 |
136 | def forward(self, in0, in1):
137 | vassert(torch.is_tensor(in0) and torch.is_tensor(in1), "Inputs must be torch tensors")
138 | vassert(in0.dim() == 4 and in0.shape[1] == 3, "Input 0 is not Bx3xHxW")
139 | vassert(in1.dim() == 4 and in1.shape[1] == 3, "Input 1 is not Bx3xHxW")
140 | if self.sample_similarity_dtype is not None:
141 | dtype = text_to_dtype(self.sample_similarity_dtype, None)
142 | vassert(
143 | dtype is not None and in0.dtype == dtype and in1.dtype == dtype, f"Unexpected input dtype ({in0.dtype})"
144 | )
145 | in0_input = self.normalize(in0)
146 | in1_input = self.normalize(in1)
147 |
148 | if self.sample_similarity_resize is not None:
149 | in0_input = self.resize(in0_input, self.sample_similarity_resize)
150 | in1_input = self.resize(in1_input, self.sample_similarity_resize)
151 |
152 | outs0 = self.net.forward(in0_input)
153 | outs1 = self.net.forward(in1_input)
154 |
155 | feats0, feats1, diffs = {}, {}, {}
156 |
157 | for kk in range(self.L):
158 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
159 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
160 |
161 | res = [spatial_average(self.lins[kk].model(diffs[kk])) for kk in range(self.L)]
162 | val = sum(res)
163 | return val
164 |
--------------------------------------------------------------------------------
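
A usage sketch for the LPIPS measure above; the first call downloads the VGG16 backbone and the LPIPS linear-head weights, and the uint8/resize settings mirror the checks performed in forward (image sizes are arbitrary):

    import torch
    from torch_fidelity.sample_similarity_lpips import SampleSimilarityLPIPS

    lpips = SampleSimilarityLPIPS("lpips-vgg16", sample_similarity_resize=64, sample_similarity_dtype="uint8")

    a = torch.randint(0, 256, (2, 3, 64, 64), dtype=torch.uint8)
    b = torch.randint(0, 256, (2, 3, 64, 64), dtype=torch.uint8)

    with torch.no_grad():
        print(lpips(a, b))   # one perceptual distance per image pair
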
/torch_fidelity/utils_torch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | import torch
5 |
6 | from torch_fidelity.helpers import vprint
7 |
8 |
9 | def torch_maybe_compile(module, dummy_input, verbose):
10 | out = module
11 | try:
12 | compiled = torch.compile(module)
13 | try:
14 | compiled(dummy_input)
15 | vprint(verbose, "Feature extractor compiled")
16 | setattr(out, "forward_pure", out.forward)
17 | setattr(out, "forward", compiled)
18 | except Exception:
19 | vprint(verbose, "Feature extractor compiled, but failed to run. Falling back to pure torch")
20 | except Exception as e:
21 | vprint(verbose, "Feature extractor compilation failed. Falling back to pure torch")
22 | return out
23 |
24 |
25 | def torch_atomic_save(what, path):
26 | path = os.path.expanduser(path)
27 | path_dir = os.path.dirname(path)
28 | fp = tempfile.NamedTemporaryFile(delete=False, dir=path_dir)
29 | try:
30 | torch.save(what, fp)
31 | fp.close()
32 | os.rename(fp.name, path)
33 | finally:
34 | fp.close()
35 | if os.path.exists(fp.name):
36 | os.remove(fp.name)
37 |
--------------------------------------------------------------------------------
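
A short sketch of the atomic checkpoint helper above; the file name is hypothetical:

    import torch
    from torch_fidelity.utils_torch import torch_atomic_save

    # The payload is first written to a temporary file in the same directory and then renamed,
    # so a crash mid-write never leaves a truncated "state.pth" behind.
    torch_atomic_save({"step": 0, "weights": torch.zeros(4)}, "./state.pth")
    print(torch.load("./state.pth")["step"])
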
/torch_fidelity/utils_torchvision.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import warnings
3 | from contextlib import redirect_stdout
4 |
5 | import torchvision
6 |
7 | from torch_fidelity.helpers import get_kwarg
8 |
9 |
10 | def torchvision_load_pretrained_vgg16(**kwargs):
11 | verbose = get_kwarg("verbose", kwargs)
12 | with redirect_stdout(sys.stderr), warnings.catch_warnings():
13 | warnings.filterwarnings("ignore", message="The parameter 'pretrained' is deprecated")
14 | warnings.filterwarnings("ignore", message="Arguments other than a weight enum")
15 | warnings.filterwarnings(
16 | "ignore",
17 | message="'torch.load' received a zip file that looks like a TorchScript "
18 | "archive dispatching to 'torch.jit.load'",
19 | )
20 | try:
21 | out = torchvision.models.vgg16(
22 | weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1,
23 | progress=verbose,
24 | )
25 | except Exception:
26 | out = torchvision.models.vgg16(
27 | pretrained=True,
28 | progress=verbose,
29 | )
30 | return out
31 |
--------------------------------------------------------------------------------
/torch_fidelity/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.4.0-beta"
2 |
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import os
9 | import torch
10 | from . import training_stats
11 |
12 | #----------------------------------------------------------------------------
13 |
14 | def init():
15 | if 'MASTER_ADDR' not in os.environ:
16 | os.environ['MASTER_ADDR'] = 'localhost'
17 |     if 'MASTER_PORT' not in os.environ:
18 |         os.environ['MASTER_PORT'] = '29500'
19 | if 'RANK' not in os.environ:
20 | os.environ['RANK'] = '0'
21 | if 'LOCAL_RANK' not in os.environ:
22 | os.environ['LOCAL_RANK'] = '0'
23 | if 'WORLD_SIZE' not in os.environ:
24 | os.environ['WORLD_SIZE'] = '1'
25 |
26 | backend = 'gloo' if os.name == 'nt' else 'nccl'
27 | torch.distributed.init_process_group(backend=backend, init_method='env://')
28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
29 |
30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None
31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | def get_rank():
36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
37 |
38 | #----------------------------------------------------------------------------
39 |
40 | def get_world_size():
41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
42 |
43 | #----------------------------------------------------------------------------
44 |
45 | def should_stop():
46 | return False
47 |
48 | #----------------------------------------------------------------------------
49 |
50 | def update_progress(cur, total):
51 | _ = cur, total
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def print0(*args, **kwargs):
56 | if get_rank() == 0:
57 | print(*args, **kwargs)
58 |
59 | #----------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
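
A sketch of the intended call pattern for the helpers above; init() falls back to a single-process group when the torchrun environment variables are absent, and the sketch assumes the repository root is on PYTHONPATH and a CUDA device is available (the NCCL backend is selected on non-Windows hosts):

    # Typical launch: torchrun --standalone --nproc_per_node=1 this_script.py
    from torch_utils import distributed as dist

    dist.init()
    dist.print0(f"running on rank {dist.get_rank()} of {dist.get_world_size()} process(es)")
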
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | import re
9 | import contextlib
10 | import numpy as np
11 | import torch
12 | import warnings
13 | import dnnlib
14 |
15 | #----------------------------------------------------------------------------
16 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
17 | # same constant is used multiple times.
18 |
19 | _constant_cache = dict()
20 |
21 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
22 | value = np.asarray(value)
23 | if shape is not None:
24 | shape = tuple(shape)
25 | if dtype is None:
26 | dtype = torch.get_default_dtype()
27 | if device is None:
28 | device = torch.device('cpu')
29 | if memory_format is None:
30 | memory_format = torch.contiguous_format
31 |
32 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
33 | tensor = _constant_cache.get(key, None)
34 | if tensor is None:
35 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
36 | if shape is not None:
37 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
38 | tensor = tensor.contiguous(memory_format=memory_format)
39 | _constant_cache[key] = tensor
40 | return tensor
41 |
42 | #----------------------------------------------------------------------------
43 | # Replace NaN/Inf with specified numerical values.
44 |
45 | try:
46 | nan_to_num = torch.nan_to_num # 1.8.0a0
47 | except AttributeError:
48 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
49 | assert isinstance(input, torch.Tensor)
50 | if posinf is None:
51 | posinf = torch.finfo(input.dtype).max
52 | if neginf is None:
53 | neginf = torch.finfo(input.dtype).min
54 | assert nan == 0
55 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
56 |
57 | #----------------------------------------------------------------------------
58 | # Symbolic assert.
59 |
60 | try:
61 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
62 | except AttributeError:
63 | symbolic_assert = torch.Assert # 1.7.0
64 |
65 | #----------------------------------------------------------------------------
66 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
67 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
68 |
69 | @contextlib.contextmanager
70 | def suppress_tracer_warnings():
71 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
72 | warnings.filters.insert(0, flt)
73 | yield
74 | warnings.filters.remove(flt)
75 |
76 | #----------------------------------------------------------------------------
77 | # Assert that the shape of a tensor matches the given list of integers.
78 | # None indicates that the size of a dimension is allowed to vary.
79 | # Performs symbolic assertion when used in torch.jit.trace().
80 |
81 | def assert_shape(tensor, ref_shape):
82 | if tensor.ndim != len(ref_shape):
83 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85 | if ref_size is None:
86 | pass
87 | elif isinstance(ref_size, torch.Tensor):
88 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
89 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90 | elif isinstance(size, torch.Tensor):
91 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
92 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93 | elif size != ref_size:
94 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95 |
96 | #----------------------------------------------------------------------------
97 | # Function decorator that calls torch.autograd.profiler.record_function().
98 |
99 | def profiled_function(fn):
100 | def decorator(*args, **kwargs):
101 | with torch.autograd.profiler.record_function(fn.__name__):
102 | return fn(*args, **kwargs)
103 | decorator.__name__ = fn.__name__
104 | return decorator
105 |
106 | #----------------------------------------------------------------------------
107 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
108 | # indefinitely, shuffling items as it goes.
109 |
110 | class InfiniteSampler(torch.utils.data.Sampler):
111 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112 | assert len(dataset) > 0
113 | assert num_replicas > 0
114 | assert 0 <= rank < num_replicas
115 | assert 0 <= window_size <= 1
116 | super().__init__(dataset)
117 | self.dataset = dataset
118 | self.rank = rank
119 | self.num_replicas = num_replicas
120 | self.shuffle = shuffle
121 | self.seed = seed
122 | self.window_size = window_size
123 |
124 | def __iter__(self):
125 | order = np.arange(len(self.dataset))
126 | rnd = None
127 | window = 0
128 | if self.shuffle:
129 | rnd = np.random.RandomState(self.seed)
130 | rnd.shuffle(order)
131 | window = int(np.rint(order.size * self.window_size))
132 |
133 | idx = 0
134 | while True:
135 | i = idx % order.size
136 | if idx % self.num_replicas == self.rank:
137 | yield order[i]
138 | if window >= 2:
139 | j = (i - rnd.randint(window)) % order.size
140 | order[i], order[j] = order[j], order[i]
141 | idx += 1
142 |
143 | #----------------------------------------------------------------------------
144 | # Utilities for operating with torch.nn.Module parameters and buffers.
145 |
146 | def params_and_buffers(module):
147 | assert isinstance(module, torch.nn.Module)
148 | return list(module.parameters()) + list(module.buffers())
149 |
150 | def named_params_and_buffers(module):
151 | assert isinstance(module, torch.nn.Module)
152 | return list(module.named_parameters()) + list(module.named_buffers())
153 |
154 | @torch.no_grad()
155 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
156 | assert isinstance(src_module, torch.nn.Module)
157 | assert isinstance(dst_module, torch.nn.Module)
158 | src_tensors = dict(named_params_and_buffers(src_module))
159 | for name, tensor in named_params_and_buffers(dst_module):
160 | assert (name in src_tensors) or (not require_all)
161 | if name in src_tensors:
162 | tensor.copy_(src_tensors[name])
163 |
164 | #----------------------------------------------------------------------------
165 | # Context manager for easily enabling/disabling DistributedDataParallel
166 | # synchronization.
167 |
168 | @contextlib.contextmanager
169 | def ddp_sync(module, sync):
170 | assert isinstance(module, torch.nn.Module)
171 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172 | yield
173 | else:
174 | with module.no_sync():
175 | yield
176 |
177 | #----------------------------------------------------------------------------
178 | # Check DistributedDataParallel consistency across processes.
179 |
180 | def check_ddp_consistency(module, ignore_regex=None):
181 | assert isinstance(module, torch.nn.Module)
182 | for name, tensor in named_params_and_buffers(module):
183 | fullname = type(module).__name__ + '.' + name
184 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185 | continue
186 | tensor = tensor.detach()
187 | if tensor.is_floating_point():
188 | tensor = nan_to_num(tensor)
189 | other = tensor.clone()
190 | torch.distributed.broadcast(tensor=other, src=0)
191 | assert (tensor == other).all(), fullname
192 |
193 | #----------------------------------------------------------------------------
194 | # Print summary table of module hierarchy.
195 |
196 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197 | assert isinstance(module, torch.nn.Module)
198 | assert not isinstance(module, torch.jit.ScriptModule)
199 | assert isinstance(inputs, (tuple, list))
200 |
201 | # Register hooks.
202 | entries = []
203 | nesting = [0]
204 | def pre_hook(_mod, _inputs):
205 | nesting[0] += 1
206 | def post_hook(mod, _inputs, outputs):
207 | nesting[0] -= 1
208 | if nesting[0] <= max_nesting:
209 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214 |
215 | # Run module.
216 | outputs = module(*inputs)
217 | for hook in hooks:
218 | hook.remove()
219 |
220 | # Identify unique outputs, parameters, and buffers.
221 | tensors_seen = set()
222 | for e in entries:
223 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227 |
228 | # Filter out redundant entries.
229 | if skip_redundant:
230 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231 |
232 | # Construct table.
233 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234 | rows += [['---'] * len(rows[0])]
235 | param_total = 0
236 | buffer_total = 0
237 | submodule_names = {mod: name for name, mod in module.named_modules()}
238 | for e in entries:
239 | name = '' if e.mod is module else submodule_names[e.mod]
240 | param_size = sum(t.numel() for t in e.unique_params)
241 | buffer_size = sum(t.numel() for t in e.unique_buffers)
242 | output_shapes = [str(list(t.shape)) for t in e.outputs]
243 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244 | rows += [[
245 | name + (':0' if len(e.outputs) >= 2 else ''),
246 | str(param_size) if param_size else '-',
247 | str(buffer_size) if buffer_size else '-',
248 | (output_shapes + ['-'])[0],
249 | (output_dtypes + ['-'])[0],
250 | ]]
251 | for idx in range(1, len(e.outputs)):
252 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253 | param_total += param_size
254 | buffer_total += buffer_size
255 | rows += [['---'] * len(rows[0])]
256 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257 |
258 | # Print table.
259 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260 | print()
261 | for row in rows:
262 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263 | print()
264 | return outputs
265 |
266 | #----------------------------------------------------------------------------
267 |
--------------------------------------------------------------------------------
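
Note (usage sketch, not part of the repository): InfiniteSampler and assert_shape are typically combined as below; the toy TensorDataset stands in for a real dataset.

    import torch
    from torch_utils import misc

    # Toy dataset: 100 random 3x64x64 "images".
    dataset = torch.utils.data.TensorDataset(torch.randn(100, 3, 64, 64))
    sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
    loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8))

    images, = next(loader)                      # the infinite sampler never raises StopIteration
    misc.assert_shape(images, [8, 3, 64, 64])   # use None for a dimension that may vary
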
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for pickling Python code alongside other data.
9 |
10 | The pickled code is automatically imported into a separate Python module
11 | during unpickling. This way, any previously exported pickles will remain
12 | usable even if the original code is no longer available, or if the current
13 | version of the code is not consistent with what was originally pickled."""
14 |
15 | import sys
16 | import pickle
17 | import io
18 | import inspect
19 | import copy
20 | import uuid
21 | import types
22 | import dnnlib
23 |
24 | #----------------------------------------------------------------------------
25 |
26 | _version = 6 # internal version number
27 | _decorators = set() # {decorator_class, ...}
28 | _import_hooks = [] # [hook_function, ...]
29 | _module_to_src_dict = dict() # {module: src, ...}
30 | _src_to_module_dict = dict() # {src: module, ...}
31 |
32 | #----------------------------------------------------------------------------
33 |
34 | def persistent_class(orig_class):
35 | r"""Class decorator that extends a given class to save its source code
36 | when pickled.
37 |
38 | Example:
39 |
40 | from torch_utils import persistence
41 |
42 | @persistence.persistent_class
43 | class MyNetwork(torch.nn.Module):
44 | def __init__(self, num_inputs, num_outputs):
45 | super().__init__()
46 | self.fc = MyLayer(num_inputs, num_outputs)
47 | ...
48 |
49 | @persistence.persistent_class
50 | class MyLayer(torch.nn.Module):
51 | ...
52 |
53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
54 | source code alongside other internal state (e.g., parameters, buffers,
55 | and submodules). This way, any previously exported pickle will remain
56 | usable even if the class definitions have been modified or are no
57 | longer available.
58 |
59 | The decorator saves the source code of the entire Python module
60 | containing the decorated class. It does *not* save the source code of
61 | any imported modules. Thus, the imported modules must be available
62 | during unpickling, also including `torch_utils.persistence` itself.
63 |
64 | It is ok to call functions defined in the same module from the
65 | decorated class. However, if the decorated class depends on other
66 | classes defined in the same module, they must be decorated as well.
67 | This is illustrated in the above example in the case of `MyLayer`.
68 |
69 | It is also possible to employ the decorator just-in-time before
70 | calling the constructor. For example:
71 |
72 | cls = MyLayer
73 | if want_to_make_it_persistent:
74 | cls = persistence.persistent_class(cls)
75 | layer = cls(num_inputs, num_outputs)
76 |
77 | As an additional feature, the decorator also keeps track of the
78 | arguments that were used to construct each instance of the decorated
79 | class. The arguments can be queried via `obj.init_args` and
80 | `obj.init_kwargs`, and they are automatically pickled alongside other
81 | object state. This feature can be disabled on a per-instance basis
82 | by setting `self._record_init_args = False` in the constructor.
83 |
84 | A typical use case is to first unpickle a previous instance of a
85 | persistent class, and then upgrade it to use the latest version of
86 | the source code:
87 |
88 | with open('old_pickle.pkl', 'rb') as f:
89 | old_net = pickle.load(f)
90 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92 | """
93 | assert isinstance(orig_class, type)
94 | if is_persistent(orig_class):
95 | return orig_class
96 |
97 | assert orig_class.__module__ in sys.modules
98 | orig_module = sys.modules[orig_class.__module__]
99 | orig_module_src = _module_to_src(orig_module)
100 |
101 | class Decorator(orig_class):
102 | _orig_module_src = orig_module_src
103 | _orig_class_name = orig_class.__name__
104 |
105 | def __init__(self, *args, **kwargs):
106 | super().__init__(*args, **kwargs)
107 | record_init_args = getattr(self, '_record_init_args', True)
108 | self._init_args = copy.deepcopy(args) if record_init_args else None
109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
110 | assert orig_class.__name__ in orig_module.__dict__
111 | _check_pickleable(self.__reduce__())
112 |
113 | @property
114 | def init_args(self):
115 | assert self._init_args is not None
116 | return copy.deepcopy(self._init_args)
117 |
118 | @property
119 | def init_kwargs(self):
120 | assert self._init_kwargs is not None
121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
122 |
123 | def __reduce__(self):
124 | fields = list(super().__reduce__())
125 | fields += [None] * max(3 - len(fields), 0)
126 | if fields[0] is not _reconstruct_persistent_obj:
127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
128 | fields[0] = _reconstruct_persistent_obj # reconstruct func
129 | fields[1] = (meta,) # reconstruct args
130 | fields[2] = None # state dict
131 | return tuple(fields)
132 |
133 | Decorator.__name__ = orig_class.__name__
134 | Decorator.__module__ = orig_class.__module__
135 | _decorators.add(Decorator)
136 | return Decorator
137 |
138 | #----------------------------------------------------------------------------
139 |
140 | def is_persistent(obj):
141 | r"""Test whether the given object or class is persistent, i.e.,
142 | whether it will save its source code when pickled.
143 | """
144 | try:
145 | if obj in _decorators:
146 | return True
147 | except TypeError:
148 | pass
149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150 |
151 | #----------------------------------------------------------------------------
152 |
153 | def import_hook(hook):
154 | r"""Register an import hook that is called whenever a persistent object
155 | is being unpickled. A typical use case is to patch the pickled source
156 | code to avoid errors and inconsistencies when the API of some imported
157 | module has changed.
158 |
159 | The hook should have the following signature:
160 |
161 | hook(meta) -> modified meta
162 |
163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164 |
165 | type: Type of the persistent object, e.g. `'class'`.
166 | version: Internal version number of `torch_utils.persistence`.
167 |         module_src: Original source code of the Python module.
168 | class_name: Class name in the original Python module.
169 | state: Internal state of the object.
170 |
171 | Example:
172 |
173 | @persistence.import_hook
174 | def wreck_my_network(meta):
175 | if meta.class_name == 'MyNetwork':
176 | print('MyNetwork is being imported. I will wreck it!')
177 | meta.module_src = meta.module_src.replace("True", "False")
178 | return meta
179 | """
180 | assert callable(hook)
181 | _import_hooks.append(hook)
182 |
183 | #----------------------------------------------------------------------------
184 |
185 | def _reconstruct_persistent_obj(meta):
186 | r"""Hook that is called internally by the `pickle` module to unpickle
187 | a persistent object.
188 | """
189 | meta = dnnlib.EasyDict(meta)
190 | meta.state = dnnlib.EasyDict(meta.state)
191 | for hook in _import_hooks:
192 | meta = hook(meta)
193 | assert meta is not None
194 |
195 | assert meta.version == _version
196 | module = _src_to_module(meta.module_src)
197 |
198 | assert meta.type == 'class'
199 | orig_class = module.__dict__[meta.class_name]
200 | decorator_class = persistent_class(orig_class)
201 | obj = decorator_class.__new__(decorator_class)
202 |
203 | setstate = getattr(obj, '__setstate__', None)
204 | if callable(setstate):
205 | setstate(meta.state) # pylint: disable=not-callable
206 | else:
207 | obj.__dict__.update(meta.state)
208 | return obj
209 |
210 | #----------------------------------------------------------------------------
211 |
212 | def _module_to_src(module):
213 | r"""Query the source code of a given Python module.
214 | """
215 | src = _module_to_src_dict.get(module, None)
216 | if src is None:
217 | src = inspect.getsource(module)
218 | _module_to_src_dict[module] = src
219 | _src_to_module_dict[src] = module
220 | return src
221 |
222 | def _src_to_module(src):
223 | r"""Get or create a Python module for the given source code.
224 | """
225 | module = _src_to_module_dict.get(src, None)
226 | if module is None:
227 | module_name = "_imported_module_" + uuid.uuid4().hex
228 | module = types.ModuleType(module_name)
229 | sys.modules[module_name] = module
230 | _module_to_src_dict[module] = src
231 | _src_to_module_dict[src] = module
232 | exec(src, module.__dict__) # pylint: disable=exec-used
233 | return module
234 |
235 | #----------------------------------------------------------------------------
236 |
237 | def _check_pickleable(obj):
238 | r"""Check that the given object is pickleable, raising an exception if
239 | it is not. This function is expected to be considerably more efficient
240 | than actually pickling the object.
241 | """
242 | def recurse(obj):
243 | if isinstance(obj, (list, tuple, set)):
244 | return [recurse(x) for x in obj]
245 | if isinstance(obj, dict):
246 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248 | return None # Python primitive types are pickleable.
249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250 | return None # NumPy arrays and PyTorch tensors are pickleable.
251 | if is_persistent(obj):
252 | return None # Persistent objects are pickleable, by virtue of the constructor check.
253 | return obj
254 | with io.BytesIO() as f:
255 | pickle.dump(recurse(obj), f)
256 |
257 | #----------------------------------------------------------------------------
258 |
--------------------------------------------------------------------------------
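
Note (usage sketch, not part of the repository): the decorator embeds the defining module's source in the pickle, so the snippet below must live in an importable .py file (inspect.getsource() cannot read an interactive session); TinyNet is a hypothetical class.

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(torch.nn.Module):
        def __init__(self, width=8):
            super().__init__()
            self.fc = torch.nn.Linear(width, width)

    net = TinyNet(width=16)
    data = pickle.dumps(net)          # module source is stored alongside the parameters
    restored = pickle.loads(data)     # loadable even if TinyNet is later changed or deleted
    print(restored.init_kwargs)       # constructor kwargs recorded by the decorator: {'width': 16}
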
/torch_utils/torch_dct.py:
--------------------------------------------------------------------------------
1 | """Taken from https://github.com/zh217/torch-dct/blob/master/torch_dct/_dct.py
2 | Some modifications have been made to work with newer versions of PyTorch."""
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | import time
8 |
9 | def re_arrange(img, img_size):
10 | b, c, h, w = img.shape
11 | x = img.reshape(b, c, h//img_size, img_size, w//img_size, img_size).permute(0, 2, 4, 1, 3, 5).reshape(-1, c, img_size, img_size)
12 | return x
13 |
14 |
15 | def recover(img, origin_img_size):
16 | img_size = img.shape[-1]
17 | x = img.reshape(-1, origin_img_size//img_size, origin_img_size//img_size, 3, img_size, img_size).permute(0, 3, 1, 4, 2, 5).reshape(-1, 3, origin_img_size, origin_img_size)
18 | return x
19 |
20 |
21 | def dct(x, norm=None):
22 | """
23 | Discrete Cosine Transform, Type II (a.k.a. the DCT)
24 | For the meaning of the parameter `norm`, see:
25 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
26 | :param x: the input signal
27 | :param norm: the normalization, None or 'ortho'
28 | :return: the DCT-II of the signal over the last dimension
29 | """
30 | x_shape = x.shape
31 | N = x_shape[-1]
32 | x = x.contiguous().view(-1, N)
33 |
34 | v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
35 |
36 | #Vc = torch.fft.rfft(v, 1)
37 | Vc = torch.view_as_real(torch.fft.fft(v, dim=1))
38 |
39 | k = - torch.arange(N, dtype=x.dtype,
40 | device=x.device)[None, :] * np.pi / (2 * N)
41 | W_r = torch.cos(k)
42 | W_i = torch.sin(k)
43 |
44 | V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
45 |
46 | if norm == 'ortho':
47 | V[:, 0] /= np.sqrt(N) * 2
48 | V[:, 1:] /= np.sqrt(N / 2) * 2
49 |
50 | V = 2 * V.view(*x_shape)
51 |
52 | return V
53 |
54 |
55 | def idct(X, norm=None):
56 | """
57 | The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
58 | Our definition of idct is that idct(dct(x)) == x
59 | For the meaning of the parameter `norm`, see:
60 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
61 | :param X: the input signal
62 | :param norm: the normalization, None or 'ortho'
63 | :return: the inverse DCT-II of the signal over the last dimension
64 | """
65 |
66 | x_shape = X.shape
67 | N = x_shape[-1]
68 |
69 | X_v = X.contiguous().view(-1, x_shape[-1]) / 2
70 |
71 | if norm == 'ortho':
72 | X_v[:, 0] *= np.sqrt(N) * 2
73 | X_v[:, 1:] *= np.sqrt(N / 2) * 2
74 |
75 | k = torch.arange(x_shape[-1], dtype=X.dtype,
76 | device=X.device)[None, :] * np.pi / (2 * N)
77 | W_r = torch.cos(k)
78 | W_i = torch.sin(k)
79 |
80 | V_t_r = X_v
81 | V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
82 |
83 | V_r = V_t_r * W_r - V_t_i * W_i
84 | V_i = V_t_r * W_i + V_t_i * W_r
85 |
86 | V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
87 |
88 | #v = torch.fft.irfft(V, 1)
89 | v = torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
90 | x = v.new_zeros(v.shape)
91 | x[:, ::2] += v[:, :N - (N // 2)]
92 | x[:, 1::2] += v.flip([1])[:, :N // 2]
93 |
94 | return x.view(*x_shape)
95 |
96 |
97 | def dct_2d(x, size, norm=None):
98 | """
99 |     2-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
100 | For the meaning of the parameter `norm`, see:
101 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
102 | :param x: the input signal
103 | :param norm: the normalization, None or 'ortho'
104 | :return: the DCT-II of the signal over the last 2 dimensions
105 | """
106 | start = time.time()
107 | origin_size = x.shape[-1]
108 | if origin_size > size:
109 | x = re_arrange(x, size)
110 | # print('time1:', time.time()-start)
111 | X1 = dct(x, norm=norm)
112 | X2 = dct(X1.transpose(-1, -2), norm=norm)
113 | X2 = X2.transpose(-1, -2)
114 | # print('time2:', time.time()-start)
115 | if origin_size > size:
116 | X2 = recover(X2, origin_size)
117 | # print('time3:', time.time()-start)
118 | return X2
119 |
120 | def idct_2d(X, size, norm=None):
121 | """
122 | The inverse to 2D DCT-II, which is a scaled Discrete Cosine Transform, Type III
123 | Our definition of idct is that idct_2d(dct_2d(x)) == x
124 | For the meaning of the parameter `norm`, see:
125 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
126 | :param X: the input signal
127 | :param norm: the normalization, None or 'ortho'
128 |     :return: the inverse DCT-II of the signal over the last 2 dimensions
129 | """
130 | start = time.time()
131 | origin_size = X.shape[-1]
132 | if origin_size > size:
133 | X = re_arrange(X, size)
134 |
135 | x1 = idct(X, norm=norm)
136 | x2 = idct(x1.transpose(-1, -2), norm=norm)
137 | x2 = x2.transpose(-1, -2)
138 |
139 | if origin_size > size:
140 | x2 = recover(x2, origin_size)
141 | # print('time4:', time.time()-start)
142 | return x2
143 |
144 | def dct_3d(x, norm=None):
145 | """
146 |     3-dimensional Discrete Cosine Transform, Type II (a.k.a. the DCT)
147 | For the meaning of the parameter `norm`, see:
148 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
149 | :param x: the input signal
150 | :param norm: the normalization, None or 'ortho'
151 | :return: the DCT-II of the signal over the last 3 dimensions
152 | """
153 | X1 = dct(x, norm=norm)
154 | X2 = dct(X1.transpose(-1, -2), norm=norm)
155 | X3 = dct(X2.transpose(-1, -3), norm=norm)
156 | return X3.transpose(-1, -3).transpose(-1, -2)
157 |
158 |
159 | def idct_3d(X, norm=None):
160 | """
161 | The inverse to 3D DCT-II, which is a scaled Discrete Cosine Transform, Type III
162 | Our definition of idct is that idct_3d(dct_3d(x)) == x
163 | For the meaning of the parameter `norm`, see:
164 | https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
165 | :param X: the input signal
166 | :param norm: the normalization, None or 'ortho'
167 |     :return: the inverse DCT-II of the signal over the last 3 dimensions
168 | """
169 | x1 = idct(X, norm=norm)
170 | x2 = idct(x1.transpose(-1, -2), norm=norm)
171 | x3 = idct(x2.transpose(-1, -3), norm=norm)
172 | return x3.transpose(-1, -3).transpose(-1, -2)
173 |
174 |
175 | class LinearDCT(nn.Linear):
176 | """Implement any DCT as a linear layer; in practice this executes around
177 | 50x faster on GPU. Unfortunately, the DCT matrix is stored, which will
178 | increase memory usage.
179 | :param in_features: size of expected input
180 | :param type: which dct function in this file to use"""
181 |
182 | def __init__(self, in_features, type, norm=None, bias=False):
183 | self.type = type
184 | self.N = in_features
185 | self.norm = norm
186 | super(LinearDCT, self).__init__(in_features, in_features, bias=bias)
187 |
188 | def reset_parameters(self):
189 | # initialise using dct function
190 | I = torch.eye(self.N)
191 | if self.type == 'dct':
192 | self.weight.data = dct(I, norm=self.norm).data.t()
193 | elif self.type == 'idct':
194 | self.weight.data = idct(I, norm=self.norm).data.t()
195 | self.weight.requires_grad = False # don't learn this!
196 |
197 |
198 | def apply_linear_2d(x, linear_layer):
199 | """Can be used with a LinearDCT layer to do a 2D DCT.
200 | :param x: the input signal
201 | :param linear_layer: any PyTorch Linear layer
202 | :return: result of linear layer applied to last 2 dimensions
203 | """
204 | X1 = linear_layer(x)
205 | X2 = linear_layer(X1.transpose(-1, -2))
206 | return X2.transpose(-1, -2)
207 |
208 |
209 | def apply_linear_3d(x, linear_layer):
210 | """Can be used with a LinearDCT layer to do a 3D DCT.
211 | :param x: the input signal
212 | :param linear_layer: any PyTorch Linear layer
213 | :return: result of linear layer applied to last 3 dimensions
214 | """
215 | X1 = linear_layer(x)
216 | X2 = linear_layer(X1.transpose(-1, -2))
217 | X3 = linear_layer(X2.transpose(-1, -3))
218 | return X3.transpose(-1, -3).transpose(-1, -2)
219 |
220 |
221 | if __name__ == '__main__':
222 | x = torch.Tensor(1000, 4096)
223 | x.normal_(0, 1)
224 | linear_dct = LinearDCT(4096, 'dct')
225 | error = torch.abs(dct(x) - linear_dct(x))
226 | assert error.max() < 1e-3, (error, error.max())
227 | linear_idct = LinearDCT(4096, 'idct')
228 | error = torch.abs(idct(x) - linear_idct(x))
229 | assert error.max() < 1e-3, (error, error.max())
230 |
--------------------------------------------------------------------------------
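
Note (usage sketch, not part of the repository): dct_2d/idct_2d tile inputs larger than `size` into size x size blocks before transforming (recover() assumes 3 channels in that case), and with norm='ortho' the pair forms a round trip; the shapes below are arbitrary.

    import torch
    from torch_utils import torch_dct

    x = torch.randn(2, 3, 256, 256)                      # batch of 3-channel images
    X = torch_dct.dct_2d(x, size=64, norm='ortho')       # transformed as 64x64 blocks
    x_rec = torch_dct.idct_2d(X, size=64, norm='ortho')  # inverse with the same block size
    print((x - x_rec).abs().max())                       # expected to be near zero (float error)
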
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Facilities for reporting and collecting training statistics across
9 | multiple processes and devices. The interface is designed to minimize
10 | synchronization overhead as well as the amount of boilerplate in user
11 | code."""
12 |
13 | import re
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from . import misc
19 |
20 | #----------------------------------------------------------------------------
21 |
22 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
23 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
24 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
25 | _rank = 0 # Rank of the current process.
26 | _sync_device = None # Device to use for multiprocess communication. None = single-process.
27 | _sync_called = False # Has _sync() been called yet?
28 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
29 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30 |
31 | #----------------------------------------------------------------------------
32 |
33 | def init_multiprocessing(rank, sync_device):
34 | r"""Initializes `torch_utils.training_stats` for collecting statistics
35 | across multiple processes.
36 |
37 | This function must be called after
38 | `torch.distributed.init_process_group()` and before `Collector.update()`.
39 | The call is not necessary if multi-process collection is not needed.
40 |
41 | Args:
42 | rank: Rank of the current process.
43 | sync_device: PyTorch device to use for inter-process
44 | communication, or None to disable multi-process
45 | collection. Typically `torch.device('cuda', rank)`.
46 | """
47 | global _rank, _sync_device
48 | assert not _sync_called
49 | _rank = rank
50 | _sync_device = sync_device
51 |
52 | #----------------------------------------------------------------------------
53 |
54 | @misc.profiled_function
55 | def report(name, value):
56 | r"""Broadcasts the given set of scalars to all interested instances of
57 | `Collector`, across device and process boundaries.
58 |
59 | This function is expected to be extremely cheap and can be safely
60 | called from anywhere in the training loop, loss function, or inside a
61 | `torch.nn.Module`.
62 |
63 | Warning: The current implementation expects the set of unique names to
64 | be consistent across processes. Please make sure that `report()` is
65 | called at least once for each unique name by each process, and in the
66 | same order. If a given process has no scalars to broadcast, it can do
67 | `report(name, [])` (empty list).
68 |
69 | Args:
70 | name: Arbitrary string specifying the name of the statistic.
71 | Averages are accumulated separately for each unique name.
72 | value: Arbitrary set of scalars. Can be a list, tuple,
73 | NumPy array, PyTorch tensor, or Python scalar.
74 |
75 | Returns:
76 | The same `value` that was passed in.
77 | """
78 | if name not in _counters:
79 | _counters[name] = dict()
80 |
81 | elems = torch.as_tensor(value)
82 | if elems.numel() == 0:
83 | return value
84 |
85 | elems = elems.detach().flatten().to(_reduce_dtype)
86 | moments = torch.stack([
87 | torch.ones_like(elems).sum(),
88 | elems.sum(),
89 | elems.square().sum(),
90 | ])
91 | assert moments.ndim == 1 and moments.shape[0] == _num_moments
92 | moments = moments.to(_counter_dtype)
93 |
94 | device = moments.device
95 | if device not in _counters[name]:
96 | _counters[name][device] = torch.zeros_like(moments)
97 | _counters[name][device].add_(moments)
98 | return value
99 |
100 | #----------------------------------------------------------------------------
101 |
102 | def report0(name, value):
103 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
104 | but ignores any scalars provided by the other processes.
105 | See `report()` for further details.
106 | """
107 | report(name, value if _rank == 0 else [])
108 | return value
109 |
110 | #----------------------------------------------------------------------------
111 |
112 | class Collector:
113 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
114 | computes their long-term averages (mean and standard deviation) over
115 | user-defined periods of time.
116 |
117 | The averages are first collected into internal counters that are not
118 | directly visible to the user. They are then copied to the user-visible
119 | state as a result of calling `update()` and can then be queried using
120 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
121 | internal counters for the next round, so that the user-visible state
122 | effectively reflects averages collected between the last two calls to
123 | `update()`.
124 |
125 | Args:
126 | regex: Regular expression defining which statistics to
127 | collect. The default is to collect everything.
128 | keep_previous: Whether to retain the previous averages if no
129 | scalars were collected on a given round
130 | (default: True).
131 | """
132 | def __init__(self, regex='.*', keep_previous=True):
133 | self._regex = re.compile(regex)
134 | self._keep_previous = keep_previous
135 | self._cumulative = dict()
136 | self._moments = dict()
137 | self.update()
138 | self._moments.clear()
139 |
140 | def names(self):
141 | r"""Returns the names of all statistics broadcasted so far that
142 | match the regular expression specified at construction time.
143 | """
144 | return [name for name in _counters if self._regex.fullmatch(name)]
145 |
146 | def update(self):
147 | r"""Copies current values of the internal counters to the
148 | user-visible state and resets them for the next round.
149 |
150 | If `keep_previous=True` was specified at construction time, the
151 | operation is skipped for statistics that have received no scalars
152 | since the last update, retaining their previous averages.
153 |
154 | This method performs a number of GPU-to-CPU transfers and one
155 | `torch.distributed.all_reduce()`. It is intended to be called
156 | periodically in the main training loop, typically once every
157 | N training steps.
158 | """
159 | if not self._keep_previous:
160 | self._moments.clear()
161 | for name, cumulative in _sync(self.names()):
162 | if name not in self._cumulative:
163 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
164 | delta = cumulative - self._cumulative[name]
165 | self._cumulative[name].copy_(cumulative)
166 | if float(delta[0]) != 0:
167 | self._moments[name] = delta
168 |
169 | def _get_delta(self, name):
170 | r"""Returns the raw moments that were accumulated for the given
171 | statistic between the last two calls to `update()`, or zero if
172 | no scalars were collected.
173 | """
174 | assert self._regex.fullmatch(name)
175 | if name not in self._moments:
176 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
177 | return self._moments[name]
178 |
179 | def num(self, name):
180 | r"""Returns the number of scalars that were accumulated for the given
181 | statistic between the last two calls to `update()`, or zero if
182 | no scalars were collected.
183 | """
184 | delta = self._get_delta(name)
185 | return int(delta[0])
186 |
187 | def mean(self, name):
188 | r"""Returns the mean of the scalars that were accumulated for the
189 | given statistic between the last two calls to `update()`, or NaN if
190 | no scalars were collected.
191 | """
192 | delta = self._get_delta(name)
193 | if int(delta[0]) == 0:
194 | return float('nan')
195 | return float(delta[1] / delta[0])
196 |
197 | def std(self, name):
198 | r"""Returns the standard deviation of the scalars that were
199 | accumulated for the given statistic between the last two calls to
200 | `update()`, or NaN if no scalars were collected.
201 | """
202 | delta = self._get_delta(name)
203 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
204 | return float('nan')
205 | if int(delta[0]) == 1:
206 | return float(0)
207 | mean = float(delta[1] / delta[0])
208 | raw_var = float(delta[2] / delta[0])
209 | return np.sqrt(max(raw_var - np.square(mean), 0))
210 |
211 | def as_dict(self):
212 | r"""Returns the averages accumulated between the last two calls to
213 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
214 |
215 | dnnlib.EasyDict(
216 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
217 | ...
218 | )
219 | """
220 | stats = dnnlib.EasyDict()
221 | for name in self.names():
222 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
223 | return stats
224 |
225 | def __getitem__(self, name):
226 | r"""Convenience getter.
227 | `collector[name]` is a synonym for `collector.mean(name)`.
228 | """
229 | return self.mean(name)
230 |
231 | #----------------------------------------------------------------------------
232 |
233 | def _sync(names):
234 | r"""Synchronize the global cumulative counters across devices and
235 | processes. Called internally by `Collector.update()`.
236 | """
237 | if len(names) == 0:
238 | return []
239 | global _sync_called
240 | _sync_called = True
241 |
242 | # Collect deltas within current rank.
243 | deltas = []
244 | device = _sync_device if _sync_device is not None else torch.device('cpu')
245 | for name in names:
246 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
247 | for counter in _counters[name].values():
248 | delta.add_(counter.to(device))
249 | counter.copy_(torch.zeros_like(counter))
250 | deltas.append(delta)
251 | deltas = torch.stack(deltas)
252 |
253 | # Sum deltas across ranks.
254 | if _sync_device is not None:
255 | torch.distributed.all_reduce(deltas)
256 |
257 | # Update cumulative values.
258 | deltas = deltas.cpu()
259 | for idx, name in enumerate(names):
260 | if name not in _cumulative:
261 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
262 | _cumulative[name].add_(deltas[idx])
263 |
264 | # Return name-value pairs.
265 | return [(name, _cumulative[name]) for name in names]
266 |
267 | #----------------------------------------------------------------------------
268 | # Convenience.
269 |
270 | default_collector = Collector()
271 |
272 | #----------------------------------------------------------------------------
273 |
--------------------------------------------------------------------------------
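
Note (usage sketch, not part of the repository): a single-process reporting/collection cycle; 'Loss/total' is an arbitrary statistic name.

    import torch
    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')

    for step in range(100):
        loss = torch.rand([])                          # stand-in for a real training loss
        training_stats.report('Loss/total', loss)      # cheap; safe to call anywhere

        if (step + 1) % 50 == 0:
            collector.update()                         # one all_reduce in the multi-process case
            print(f'step {step + 1}: '
                  f'mean={collector.mean("Loss/total"):.3f} std={collector.std("Loss/total"):.3f}')
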