├── .env ├── .gitattributes ├── .gitignore ├── README.md ├── __init__.py ├── adversarial ├── embedding.py ├── feature_extractors │ ├── __init__.py │ ├── base_encoder.py │ ├── clip.py │ ├── kl_vae_hg.py │ ├── kl_vae_ldm.py │ └── resnet18.py ├── surrogate.py ├── surrogate_models │ ├── adv_cls_real_wm_stable_sig.pth │ ├── adv_cls_real_wm_stegastamp.pth │ ├── adv_cls_real_wm_tree_ring.pth │ ├── adv_cls_unwm_wm_stable_sig.pth │ ├── adv_cls_unwm_wm_stegastamp.pth │ ├── adv_cls_unwm_wm_tree_ring.pth │ ├── adv_cls_wm1_wm2_stable_sig.pth │ ├── adv_cls_wm1_wm2_stegastamp.pth │ └── adv_cls_wm1_wm2_tree_ring.pth └── train.py ├── app.py ├── cli.py ├── decoders ├── stable_signature.onnx └── stega_stamp.onnx ├── dev ├── __init__.py ├── aggregate.py ├── constants.py ├── eval.py ├── find.py ├── io.py ├── parse.py └── plot.py ├── distortions ├── __init__.py └── distortions.py ├── guided_diffusion ├── __init__.py ├── dist_util.py ├── fp16_util.py ├── gaussian_diffusion.py ├── generate.py ├── image_datasets.py ├── logger.py ├── losses.py ├── model_utils.py ├── nn.py ├── resample.py ├── respace.py ├── script_util.py ├── train_util.py └── unet.py ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── metrics ├── __init__.py ├── aesthetics.py ├── aesthetics_scorer │ ├── __init__.py │ └── model.py ├── clean_fid │ ├── __init__.py │ ├── clip_features.py │ ├── downloads_helper.py │ ├── features.py │ ├── fid.py │ ├── inception_pytorch.py │ ├── inception_torchscript.py │ ├── leaderboard.py │ ├── resize.py │ ├── utils.py │ └── wrappers.py ├── clip.py ├── distributional.py ├── image.py ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── pretrained_networks.py │ ├── trainer.py │ ├── utils.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── metrics │ ├── __init__.py │ ├── aesthetics.py │ ├── aesthetics_scorer │ │ ├── __init__.py │ │ ├── model.py │ │ └── weights │ │ │ ├── aesthetics_scorer_artifacts_openclip_vit_bigg_14.config │ │ │ ├── aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth │ │ │ ├── aesthetics_scorer_artifacts_openclip_vit_h_14.config │ │ │ ├── aesthetics_scorer_artifacts_openclip_vit_h_14.pth │ │ │ ├── aesthetics_scorer_artifacts_openclip_vit_l_14.config │ │ │ ├── aesthetics_scorer_artifacts_openclip_vit_l_14.pth │ │ │ ├── aesthetics_scorer_rating_openclip_vit_bigg_14.config │ │ │ ├── aesthetics_scorer_rating_openclip_vit_bigg_14.pth │ │ │ ├── aesthetics_scorer_rating_openclip_vit_h_14.config │ │ │ ├── aesthetics_scorer_rating_openclip_vit_h_14.pth │ │ │ ├── aesthetics_scorer_rating_openclip_vit_l_14.config │ │ │ └── aesthetics_scorer_rating_openclip_vit_l_14.pth │ ├── clean_fid │ │ ├── __init__.py │ │ ├── clip_features.py │ │ ├── downloads_helper.py │ │ ├── features.py │ │ ├── fid.py │ │ ├── inception_pytorch.py │ │ ├── inception_torchscript.py │ │ ├── leaderboard.py │ │ ├── resize.py │ │ ├── utils.py │ │ └── wrappers.py │ ├── clip.py │ ├── distributional.py │ ├── image.py │ ├── lpips │ │ ├── __init__.py │ │ ├── lpips.py │ │ ├── pretrained_networks.py │ │ ├── trainer.py │ │ ├── utils.py 
│ │ └── weights │ │ │ ├── v0.0 │ │ │ ├── alex.pth │ │ │ ├── squeeze.pth │ │ │ └── vgg.pth │ │ │ └── v0.1 │ │ │ ├── alex.pth │ │ │ ├── squeeze.pth │ │ │ └── vgg.pth │ ├── perceptual.py │ ├── prompt.py │ └── watson │ │ ├── __init__.py │ │ ├── color_wrapper.py │ │ ├── dct2d.py │ │ ├── deep_loss.py │ │ ├── loss_provider.py │ │ ├── rfft2d.py │ │ ├── shift_wrapper.py │ │ ├── ssim.py │ │ ├── watson.py │ │ ├── watson_fft.py │ │ ├── watson_vgg.py │ │ └── weights │ │ ├── gray_adaptive_trial0.pth │ │ ├── gray_pnet_lin_squeeze_trial0.pth │ │ ├── gray_pnet_lin_vgg_trial0.pth │ │ ├── gray_watson_dct_trial0.pth │ │ ├── gray_watson_fft_trial0.pth │ │ ├── gray_watson_vgg_trial0.pth │ │ ├── rgb_adaptive_trial0.pth │ │ ├── rgb_pnet_lin_squeeze_trial0.pth │ │ ├── rgb_pnet_lin_vgg_trial0.pth │ │ ├── rgb_watson_dct_trial0.pth │ │ ├── rgb_watson_fft_trial0.pth │ │ └── rgb_watson_vgg_trial0.pth ├── perceptual.py ├── prompt.py └── workflow_a_small.png ├── regeneration ├── __init__.py └── regen.py ├── requirements_all.txt ├── requirements_attack.txt ├── requirements_cli.txt ├── requirements_space.txt ├── scripts ├── chmod.py ├── decode.py ├── metric.py ├── reverse.py └── status.py ├── setup.py ├── shell_scripts └── install_dependencies.sh ├── static └── images │ ├── 2d.jpg │ ├── 2d_ident.jpg │ ├── 2d_tree_ident.jpg │ ├── 2x_regen-100.jpg │ ├── 2x_regen-20.jpg │ ├── 4x_regen-10.jpg │ ├── 4x_regen-50.jpg │ ├── 4x_regen_kl_vae-16.jpg │ ├── 4x_regen_kl_vae-4.jpg │ ├── adv_cls_wm1_wm2_0.01_50_warm-2-tree_ring.jpg │ ├── adv_cls_wm1_wm2_0.01_50_warm-8-tree_ring.jpg │ ├── adv_emb.jpg │ ├── adv_emb_clip_untg_alphaRatio_0.05_step_200-16-tree_ring.jpg │ ├── adv_emb_clip_untg_alphaRatio_0.05_step_200-2-tree_ring.jpg │ ├── adv_emb_clip_untg_alphaRatio_0.05_step_200-8-tree_ring.jpg │ ├── adv_emb_coco.jpg │ ├── adv_emb_diff.jpg │ ├── adv_emb_same_vae_untg-2.jpg │ ├── adv_emb_same_vae_untg-8.jpg │ ├── adv_spoof.jpg │ ├── adv_su.jpg │ ├── all_fig_coco_1.jpg │ ├── all_fig_coco_2.jpg │ ├── all_fig_dalle_1.jpg │ ├── all_fig_dalle_2.jpg │ ├── all_fig_diff_1.jpg │ ├── all_fig_diff_2.jpg │ ├── bench_watermarks_detect 2.jpg │ ├── bench_watermarks_detect.jpg │ ├── bench_watermarks_ident 2.jpg │ ├── bench_watermarks_ident.jpg │ ├── carousel1.jpg │ ├── carousel2.jpg │ ├── carousel3.jpg │ ├── carousel4.jpg │ ├── dataset_dalle3_examples.jpg │ ├── dataset_dalle3_wordcloud.jpg │ ├── dataset_diffusiondb_examples.jpg │ ├── dataset_diffusiondb_wordcloud.jpg │ ├── dataset_mscoco_examples.jpg │ ├── dataset_mscoco_wordcloud.jpg │ ├── dist_com1.jpg │ ├── dist_com2.jpg │ ├── distcom-deg-0.15.jpg │ ├── distcom-geo-0.15.jpg │ ├── distcom-photo-0.15.jpg │ ├── distortion_combo_all-0.04.jpg │ ├── distortion_combo_all-0.2.jpg │ ├── example_1.jpg │ ├── example_2.jpg │ ├── example_3.jpg │ ├── example_4.jpg │ ├── example_5.jpg │ ├── favicon.ico │ ├── illu_adv_real_wm.jpg │ ├── illu_adv_unwm_wm.jpg │ ├── illu_adv_wm1_wm2.jpg │ ├── legend.jpg │ ├── no_attack.jpg │ ├── no_watermark.jpg │ ├── problem.gif │ ├── quality_metric_cdf_normalize_range.jpg │ ├── radar_iden_100.jpg │ ├── radar_iden_1000.jpg │ ├── radar_iden_1000000.jpg │ ├── radar_plot.jpg │ ├── regen-200.jpg │ ├── regen-40.jpg │ ├── regen_coco_clip.jpg │ ├── regen_coco_psnr.jpg │ ├── regen_diff_clip.jpg │ ├── regen_diff_psnr.jpg │ ├── regen_vae.jpg │ ├── spec_adv_unwm_wm.jpg │ ├── spec_adv_wm1_wm2.jpg │ ├── tree-ring-heatmap-old.jpg │ ├── unattacked-tree_ring.jpg │ ├── violin.jpg │ ├── waves.jpg │ ├── waves_small.png │ ├── workflow.jpg │ ├── workflow_a.gif │ ├── workflow_a.jpg │ ├── workflow_b.gif │ └── 
workflow_b.jpg ├── tree_ring ├── __init__.py ├── data_utils.py ├── guided_diffusion.py ├── io_utils.py ├── optim_utils.py └── stable_diffusion.py └── utils ├── __init__.py ├── data_utils.py ├── exp_utils.py ├── image_utils.py ├── io_utils.py ├── plot_utils.py └── vis_utils.py /.env: -------------------------------------------------------------------------------- 1 | # CML folders 2 | DATA_DIR=/path/to/datasets 3 | MODEL_DIR=/path/to/models 4 | # Cache folders 5 | TORCH_HOME=/tmp/torch/ 6 | HF_HOME=/tmp/huggingface/ 7 | # HF space 8 | # CUDA and CUDNN paths for ONNX 9 | LD_LIBRARY_PATH=/opt/common/cuda/cuda-11.8.0/lib64:/opt/common/cudnn/cudnn-11.x-8.8.0.121/lib64:${LD_LIBRARY_PATH} 10 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.onnx filter=lfs diff=lfs merge=lfs -text 2 | *.pth filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bundle.* 2 | lib/ 3 | node_modules/ 4 | *.egg-info/ 5 | .ipynb_checkpoints 6 | *.tsbuildinfo 7 | 8 | # Created by https://www.gitignore.io/api/python 9 | # Edit at https://www.gitignore.io/?templates=python 10 | 11 | ### Python ### 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | .spyproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | # Mr Developer 93 | .mr.developer.cfg 94 | .project 95 | .pydevproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | .dmypy.json 103 | dmypy.json 104 | 105 | # Pyre type checker 106 | .pyre/ 107 | 108 | # OS X stuff 109 | *.DS_Store 110 | 111 | # End of https://www.gitignore.io/api/python 112 | 113 | _temp_extension 114 | junit.xml 115 | [uU]ntitled* 116 | notebook/static/* 117 | !notebook/static/favicons 118 | notebook/labextension 119 | notebook/schemas 120 | docs/source/changelog.md 121 | docs/source/contributing.md 122 | 123 | # playwright 124 | ui-tests/test-results 125 | ui-tests/playwright-report 126 | 127 | # VSCode 128 | .vscode 129 | 130 | # RTC 131 | .jupyter_ystore.db 132 | 133 | # yarn >=2.x local files 134 | .yarn/* 135 | .pnp.* 136 | ui-tests/.yarn/* 137 | ui-tests/.pnp.* 138 | 139 | # virtual env 140 | venv 141 | .venv 142 | 143 | # datasets 144 | datasets 145 | 146 | # models 147 | models 148 | 149 | # wandb 150 | wandb -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/__init__.py -------------------------------------------------------------------------------- /adversarial/feature_extractors/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet18 import ResNet18Embedding 2 | from .kl_vae_hg import VAEEmbedding 3 | from .clip import ClipEmbedding 4 | from .kl_vae_ldm import KLVAEEmbedding 5 | -------------------------------------------------------------------------------- /adversarial/feature_extractors/base_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseEncoder(torch.nn.Module): 5 | def forward(self, images): 6 | raise NotImplementedError("This method should be implemented by subclasses.") 7 | -------------------------------------------------------------------------------- /adversarial/feature_extractors/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoProcessor, CLIPModel 3 | from .base_encoder import BaseEncoder 4 | from torchvision import transforms 5 | 6 | OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] 7 | OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] 8 | 9 | 10 | class ClipEmbedding(BaseEncoder): 11 | def __init__(self): 12 | super(ClipEmbedding, self).__init__() 13 | self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 14 | self.processor = 
AutoProcessor.from_pretrained("openai/clip-vit-base-patch32") 15 | self.normalizer = transforms.Compose( 16 | [ 17 | transforms.Resize((224, 224)), 18 | transforms.Normalize(mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), 19 | ] 20 | ) 21 | 22 | def forward(self, x): 23 | x = torch.clamp(x, min=0, max=1) 24 | inputs = dict(pixel_values=self.normalizer(x)) 25 | inputs["pixel_values"] = inputs["pixel_values"].cuda() 26 | outputs = self.model.get_image_features(**inputs) 27 | pooled_output = outputs 28 | return pooled_output 29 | -------------------------------------------------------------------------------- /adversarial/feature_extractors/kl_vae_hg.py: -------------------------------------------------------------------------------- 1 | from .base_encoder import BaseEncoder 2 | from diffusers.models import AutoencoderKL 3 | 4 | 5 | class VAEEmbedding(BaseEncoder): 6 | def __init__(self, model_name): 7 | super().__init__() 8 | self.model = AutoencoderKL.from_pretrained(model_name) 9 | 10 | def forward(self, images): 11 | images = 2.0 * images - 1.0 12 | output = self.model.encode(images) 13 | z = output.latent_dist.mode() 14 | return z 15 | -------------------------------------------------------------------------------- /adversarial/feature_extractors/kl_vae_ldm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_encoder import BaseEncoder 3 | from omegaconf import OmegaConf 4 | import importlib 5 | 6 | 7 | def instantiate_from_config(config): 8 | if not "target" in config: 9 | if config == "__is_first_stage__": 10 | return None 11 | elif config == "__is_unconditional__": 12 | return None 13 | raise KeyError("Expected key `target` to instantiate.") 14 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 15 | 16 | 17 | def get_obj_from_str(string, reload=False): 18 | module, cls = string.rsplit(".", 1) 19 | if reload: 20 | module_imp = importlib.import_module(module) 21 | importlib.reload(module_imp) 22 | return getattr(importlib.import_module(module, package=None), cls) 23 | 24 | 25 | class KLVAEEmbedding(BaseEncoder): 26 | def __init__(self, model_name): 27 | super().__init__() 28 | self.model = self.get_model(model_name) 29 | 30 | def load_model_from_config(self, config, ckpt): 31 | print(f"Loading model from {ckpt}") 32 | pl_sd = torch.load( 33 | ckpt, 34 | map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 35 | ) 36 | sd = pl_sd["state_dict"] 37 | model = instantiate_from_config(config.model) 38 | model.load_state_dict(sd, strict=False) 39 | 40 | delattr(model, "decoder") 41 | 42 | return model 43 | 44 | def get_model(self, name): 45 | config_path = "./models/ldm/" + name + "/config.yaml" 46 | model_path = "./models/ldm/" + name + "/model.ckpt" 47 | config = OmegaConf.load(config_path) 48 | model = self.load_model_from_config(config, model_path) 49 | return model 50 | 51 | def forward(self, images): 52 | images = 2.0 * images - 1.0 53 | output = self.model.encode(images) 54 | z = output.mode() 55 | return z 56 | -------------------------------------------------------------------------------- /adversarial/feature_extractors/resnet18.py: -------------------------------------------------------------------------------- 1 | from .base_encoder import BaseEncoder 2 | import torchvision.models as models 3 | import torch 4 | import torchvision.transforms.functional as TF 5 | 6 | 7 | class ResNet18Embedding(BaseEncoder): 8 | def __init__(self, layer): 9 | super().__init__() 10 | 
original_model = models.resnet18(pretrained=True) 11 | # Define normalization layers 12 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 13 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 14 | 15 | # Extract the desired layers from the original model 16 | if layer == "layer1": 17 | self.features = torch.nn.Sequential(*list(original_model.children())[:-6]) 18 | elif layer == "layer2": 19 | self.features = torch.nn.Sequential(*list(original_model.children())[:-5]) 20 | elif layer == "layer3": 21 | self.features = torch.nn.Sequential(*list(original_model.children())[:-4]) 22 | elif layer == "layer4": 23 | self.features = torch.nn.Sequential(*list(original_model.children())[:-3]) 24 | elif layer == "last": 25 | self.features = torch.nn.Sequential(*list(original_model.children())[:-1]) 26 | else: 27 | raise ValueError("Invalid layer name") 28 | 29 | def forward(self, images): 30 | # Normalize the input 31 | images = TF.resize(images, [224, 224]) 32 | images = (images - self.mean) / self.std 33 | return self.features(images) 34 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_real_wm_stable_sig.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e0decb39c80087a016e230489820adaaf66a10a491499b3c8db9ee3ecdf7d813 3 | size 44791101 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_real_wm_stegastamp.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ec1b46d65eda016c02be96ad894a3ed18e55f4739eb4767a8b2087600356cc9b 3 | size 44789582 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_real_wm_tree_ring.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4eebd916caf68453747de116910202f33ac8c8e82eeac5959aa96fe08d2b4fd3 3 | size 44789301 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_unwm_wm_stable_sig.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:65c5b6c319644d60c60b70f1c797f0ce215f90dee0fd3ba26aa13a54bba9797b 3 | size 44790968 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_unwm_wm_stegastamp.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:30977dd8e0d4b203eed048cf14974cc7ead4e2003bb978cc137f0995d04abde5 3 | size 44790842 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_unwm_wm_tree_ring.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4802905b169c9b7aa9343a8309c6c10c896734b46d22d4b94b3332951e238395 3 | size 44789921 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_wm1_wm2_stable_sig.pth: -------------------------------------------------------------------------------- 1 | 
version https://git-lfs.github.com/spec/v1 2 | oid sha256:254f78a9c6e78753554606288f361fb48aa2be00360d8de2fce160d56a9348d0 3 | size 44789078 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_wm1_wm2_stegastamp.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:217a9df3aeb0f7dca511c2a0e42d102cf1c8b774d9162c11be1a1d0ce5682035 3 | size 44789078 4 | -------------------------------------------------------------------------------- /adversarial/surrogate_models/adv_cls_wm1_wm2_tree_ring.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3f21612af9d7e7bbde1dbdb1004eb7b4cdc0125cae509f1d5d922a9888137e05 3 | size 44788952 4 | -------------------------------------------------------------------------------- /decoders/stable_signature.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9b58841ab09f23e89acf5aedade09c7f65908ae33437c5242ad987d99b5cd2c1 3 | size 1228161 4 | -------------------------------------------------------------------------------- /decoders/stega_stamp.onnx: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:369f8134f4e35da9659777f02468f4200c256d14742b9702f2ac28808da675e2 3 | size 218974214 4 | -------------------------------------------------------------------------------- /dev/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import ( 2 | LIMIT, 3 | SUBSET_LIMIT, 4 | DATASET_NAMES, 5 | WATERMARK_METHODS, 6 | PERFORMANCE_METRICS, 7 | QUALITY_METRICS, 8 | EVALUATION_SETUPS, 9 | GROUND_TRUTH_MESSAGES, 10 | ATTACK_NAMES, 11 | ) 12 | from .io import ( 13 | chmod_group_write, 14 | compare_dicts, 15 | load_json, 16 | save_json, 17 | encode_array_to_string, 18 | decode_array_from_string, 19 | encode_image_to_string, 20 | decode_image_from_string, 21 | ) 22 | from .find import ( 23 | check_file_existence, 24 | existence_operation, 25 | existence_to_indices, 26 | parse_image_dir_path, 27 | get_all_image_dir_paths, 28 | parse_json_path, 29 | get_all_json_paths, 30 | ) 31 | from .parse import ( 32 | get_progress_from_json, 33 | get_example_from_json, 34 | get_distances_from_json, 35 | ) 36 | from .eval import ( 37 | bit_error_rate, 38 | complex_l1, 39 | message_distance, 40 | detection_perforamance, 41 | mean_and_std, 42 | combine_means_and_stds, 43 | ) 44 | from .aggregate import ( 45 | get_performance_from_jsons, 46 | get_performance, 47 | get_single_quality_from_jsons, 48 | get_quality_from_jsons, 49 | get_quality, 50 | clear_aggregated_cache, 51 | ) 52 | from .plot import ( 53 | style_progress_dataframe, 54 | aggregate_comparison_dataframe, 55 | plot_parallel_coordinates, 56 | plot_2d_comparison, 57 | ) 58 | -------------------------------------------------------------------------------- /dev/constants.py: -------------------------------------------------------------------------------- 1 | from .io import decode_array_from_string 2 | 3 | LIMIT, SUBSET_LIMIT = 5000, 1000 4 | 5 | DATASET_NAMES = { 6 | "diffusiondb": "DiffusionDB", 7 | "mscoco": "MS-COCO", 8 | "dalle3": "DALL-E 3", 9 | } 10 | 11 | WATERMARK_METHODS = { 12 | "tree_ring": "Tree-Ring", 13 | "stable_sig": 
"Stable-Signature", 14 | "stegastamp": "Stega-Stamp", 15 | } 16 | 17 | PERFORMANCE_METRICS = { 18 | "acc_1": "Mean Accuracy", 19 | "auc_1": "AUC", 20 | "low100_1": "TPR@1%FPR", 21 | "low1000_1": "TPR@0.1%FPR", 22 | } 23 | 24 | QUALITY_METRICS = { 25 | "legacy_fid": "Legacy FID", 26 | "clip_fid": "CLIP FID", 27 | "psnr": "PSNR", 28 | "ssim": "SSIM", 29 | "nmi": "Normed Mutual-Info", 30 | "lpips": "LPIPS", 31 | "watson": "Watson-DFT", 32 | "aesthetics": "Delta Aesthetics", 33 | "artifacts": "Delta Artifacts", 34 | "clip_score": "Delta CLIP-Score", 35 | } 36 | 37 | EVALUATION_SETUPS = { 38 | "combined": "Combined", 39 | "removal": "Removal", 40 | "spoofing": "Spoofing", 41 | } 42 | 43 | GROUND_TRUTH_MESSAGES = { 44 | "tree_ring": decode_array_from_string( 45 | "H4sIALRwUmUC/42SvYrCQBSFLW18iam3EZcUFgErkYBFSLcYENZgISgoiMjCVj6FzyFCCoUlTQgrE8TnkcNhcEZIbjhw7x3Ox/zdu1fr+XQ1U/0vr/c5+VDfmx1WKlksp5uup35anX+jLHrVWFHX0FS2cw2h8x+z69KBYs1MxvVjDXkPZpvh/iC8B1SUzKTMWTZTlEV5iBBtzqVAHKJEI4KrphL9OwInU1AzUqbku9W/s+mf1f+93D+5/1Vz8z5j+djIv79qrKj2wFS20x5Al5LZdelApxszGdc/3aBUM9sM9weRasgPmEmZs2zGD/xgmyPanEuB2ObHDBFcNXXMwiE4mYKakTIl363+nU3/rP7v5f7J/a+am/cZewKA1ipNFgUAAA==" 46 | ), 47 | "stable_sig": decode_array_from_string( 48 | "H4sIADtrUmUC/6tWKs5ILEhVsoo2sYjVUUopqQRxlJLy83OUahkYGRkZgBBMMIAAhAvmwUUZIRIgcQBxGJ0kTgAAAA==" 49 | ), 50 | "stegastamp": decode_array_from_string( 51 | "H4sIAGRrUmUC/6tWKs5ILEhVsoo2NDCI1VFKKakE8ZSS8vNzlGoZGBkZGRhAmAFCgxGIC+dBCAaYEEyCEVkOzISrYIToZIAoY2AEAG5jy4ODAAAA" 52 | ), 53 | } 54 | 55 | 56 | ATTACK_NAMES = { 57 | "distortion_single_rotation": "Dist-Rotation", 58 | "distortion_single_resizedcrop": "Dist-RCrop", 59 | "distortion_single_erasing": "Dist-Erase", 60 | "distortion_single_brightness": "Dist-Bright", 61 | "distortion_single_contrast": "Dist-Contrast", 62 | "distortion_single_blurring": "Dist-Blur", 63 | "distortion_single_noise": "Dist-Noise", 64 | "distortion_single_jpeg": "Dist-JPEG", 65 | "distortion_combo_geometric": "Dist-Com-Geo", 66 | "distortion_combo_photometric": "Dist-Com-Photo", 67 | "distortion_combo_degradation": "Dist-Com-Deg", 68 | "distortion_combo_all": "Dist-Com-All", 69 | "regen_diffusion": "Regen-Diffusion", 70 | "regen_diffusion_prompt": "Regen-Diffusion&P", 71 | "regen_vae": "Regen-VAE", 72 | "kl_vae": "Regen-KLVAE", 73 | "2x_regen": "Regen-2xDiffusion", 74 | "4x_regen": "Regen-4xDiffusion", 75 | "4x_regen_bmshj": "Regen-4xVAE", 76 | "4x_regen_kl_vae": "Regen-4xKLVAE", 77 | "adv_emb_resnet18_untg": "AdvEmb-RN18", 78 | "adv_emb_clip_untg_alphaRatio_0.05_step_200": "AdvEmb-CLIP", 79 | "adv_emb_same_vae_untg": "AdvEmb-KLVAE8", 80 | "adv_emb_klf16_vae_untg": "AdvEmb-KLVAE16", 81 | "adv_emb_sdxl_vae_untg": "AdvEmb-SdxlVAE", 82 | "adv_cls_unwm_wm_0.01_50_warm_train3k": "AdvCls-UnWM-WM", 83 | "adv_cls_real_wm_0.01_50_warm": "AdvCls-Real-WM", 84 | "adv_cls_wm1_wm2_0.01_50_warm": "AdvCls-WM1-WM2", 85 | "adv_cls_wm1_wm2_0.04_200_warm": "abandon", 86 | } 87 | -------------------------------------------------------------------------------- /dev/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn import metrics 3 | 4 | 5 | def bit_error_rate(pred, target): 6 | if not pred.dtype == target.dtype == bool: 7 | raise ValueError(f"Cannot compute BER for {pred.dtype} and {target.dtype}") 8 | return np.mean(pred != target) 9 | 10 | 11 | def complex_l1(pred, target): 12 | if not pred.dtype == target.dtype == np.float16: 13 | raise ValueError( 14 | f"Cannot compute Complex L1 for {pred.dtype} and 
{target.dtype}"
        )
    # Cast to float32 to avoid large numerical errors
    pred = pred.astype(np.float32).reshape(2, -1)
    target = target.astype(np.float32).reshape(2, -1)
    return np.sqrt(((pred - target) ** 2).sum(0)).mean()


def message_distance(pred, target):
    if target.dtype == bool:
        return bit_error_rate(pred, target)
    elif target.dtype == np.float16:
        return complex_l1(pred, target)
    else:
        raise TypeError


def detection_perforamance(original_distances, watermarked_distances):
    if not len(original_distances) == len(watermarked_distances):
        raise ValueError(f"Length of distances must be equal")
    y_true = [0] * len(original_distances) + [1] * len(watermarked_distances)
    y_score = (-np.array(original_distances + watermarked_distances)).tolist()
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=1)
    acc_1 = np.max(1 - (fpr + (1 - tpr)) / 2)
    auc_1 = metrics.auc(fpr, tpr)
    low100_1 = tpr[np.where(fpr < 0.01)[0][-1]]
    low1000_1 = tpr[np.where(fpr < 0.001)[0][-1]]
    return {
        "acc_1": acc_1,
        "auc_1": auc_1,
        "low100_1": low100_1,
        "low1000_1": low1000_1,
    }


def mean_and_std(values):
    if values is None:
        return None
    return np.mean(values), np.std(values)


def combine_means_and_stds(mean_and_std1, mean_and_std2):
    if mean_and_std1 is None or mean_and_std2 is None:
        return None
    mean1, std1 = mean_and_std1
    mean2, std2 = mean_and_std2
    mean = (mean1 + mean2) / 2
    std = np.sqrt((std1**2 + std2**2) / 2)
    return mean, std

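A quick usage sketch of the evaluation helpers above (illustrative only: the bits and distances are synthetic, the import assumes the repository root is on the Python path, and the function name keeps the module's own spelling):

import numpy as np
from dev.eval import message_distance, detection_perforamance

# Bit-level distance between a decoded message and the ground-truth bits.
target = np.array([True, False, True, True] * 8)
decoded = target.copy()
decoded[:3] = ~decoded[:3]  # pretend an attack flipped three bits
print(message_distance(decoded, target))  # bit error rate, 3/32 here

# Detection metrics from per-image distances (smaller distance = more watermark-like).
unwatermarked = np.random.uniform(0.4, 0.6, size=100).tolist()
watermarked = np.random.uniform(0.0, 0.2, size=100).tolist()
print(detection_perforamance(unwatermarked, watermarked))
# -> {"acc_1": ..., "auc_1": ..., "low100_1": ..., "low1000_1": ...}
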
--------------------------------------------------------------------------------
/dev/io.py:
--------------------------------------------------------------------------------
import os
import warnings
import stat
import orjson
import gzip
import base64
from io import BytesIO
import numpy as np
from PIL import Image


def chmod_group_write(path):
    if not os.path.exists(path):
        raise ValueError(f"Path {path} does not exist")
    if os.stat(path).st_uid == os.getuid():
        current_permissions = stat.S_IMODE(os.lstat(path).st_mode)
        os.chmod(path, current_permissions | stat.S_IWGRP)


def compare_dicts(dict1, dict2):
    if dict1.keys() != dict2.keys():
        return False
    for key in dict1:
        if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
            if not compare_dicts(dict1[key], dict2[key]):
                return False
        else:
            if dict1[key] != dict2[key]:
                return False
    return True


def load_json(filepath):
    try:
        with open(filepath, "rb") as json_file:
            return orjson.loads(json_file.read())
    except orjson.JSONDecodeError:
        warnings.warn(f"Found invalid JSON file {filepath}, deleting")
        os.remove(filepath)
        return None


def save_json(data, filepath):
    if os.path.exists(filepath) and (existing_data := load_json(filepath)) is not None:
        if compare_dicts(data, existing_data):
            return
    with open(filepath, "wb") as json_file:
        json_file.write(orjson.dumps(data))
    chmod_group_write(filepath)


def encode_array_to_string(array):
    # Convert shape and dtype to byte string using orjson
    meta = orjson.dumps({"shape": array.shape, "dtype": str(array.dtype)})
    # Combine metadata and array bytes
    combined = meta + b"\x00" + array.tobytes()
    # Compress and encode to Base64
    compressed = gzip.compress(combined)
    return base64.b64encode(compressed).decode("utf-8")


def decode_array_from_string(encoded_string):
    # Decode from Base64 and decompress
    decoded_bytes = base64.b64decode(encoded_string)
    decompressed = gzip.decompress(decoded_bytes)
    # Split metadata and array data
    meta_encoded, array_bytes = decompressed.split(b"\x00", 1)
    # Deserialize metadata
    meta = orjson.loads(meta_encoded)
    shape, dtype = meta["shape"], meta["dtype"]
    # Convert bytes back to NumPy array
    return np.frombuffer(array_bytes, dtype=dtype).reshape(shape)


def encode_image_to_string(image, quality=90):
    # Save the image to a byte buffer in JPEG format
    buffered = BytesIO()
    image.save(buffered, format="JPEG", quality=quality)
    # Encode the buffer to a base64 string
    return base64.b64encode(gzip.compress(buffered.getvalue())).decode("utf-8")


def decode_image_from_string(encoded_string):
    # Decode the base64 string to bytes
    img_data = gzip.decompress(base64.b64decode(encoded_string))
    # Read the image from bytes
    image = Image.open(BytesIO(img_data))
    return image

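A round-trip sketch for the string codecs above (illustrative; the array and image are synthetic). The array codec is lossless, while the image codec goes through JPEG and is therefore only approximately lossless:

import numpy as np
from PIL import Image
from dev.io import (
    encode_array_to_string,
    decode_array_from_string,
    encode_image_to_string,
    decode_image_from_string,
)

# Array round trip: shape and dtype survive via the JSON header before the null byte.
msg = np.random.rand(2, 16).astype(np.float16)
restored = decode_array_from_string(encode_array_to_string(msg))
assert restored.dtype == msg.dtype and np.array_equal(restored, msg)

# Image round trip: gzip + base64 around a JPEG buffer (quality=90 by default).
img = Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
thumb = decode_image_from_string(encode_image_to_string(img))
assert thumb.size == img.size
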
--------------------------------------------------------------------------------
/dev/parse.py:
--------------------------------------------------------------------------------
from .constants import (
    LIMIT,
    SUBSET_LIMIT,
    WATERMARK_METHODS,
    QUALITY_METRICS,
    GROUND_TRUTH_MESSAGES,
)
from .find import parse_json_path, get_all_json_paths
from .io import load_json, decode_array_from_string, decode_image_from_string
from .eval import message_distance, detection_perforamance


def get_progress_from_json(path):
    (
        _,
        _,
        _,
        source_name,
        result_type,
    ) = parse_json_path(path)
    data = load_json(path)
    if result_type == "status":
        return sum([data[str(i)]["exist"] for i in range(LIMIT)])
    elif result_type == "reverse":
        return sum([data[str(i)] for i in range(LIMIT)])
    elif result_type == "decode":
        for mode in WATERMARK_METHODS.keys():
            if source_name.endswith(mode):
                return sum([data[str(i)][mode] is not None for i in range(LIMIT)])
        return sum(
            [
                (
                    all(
                        [
                            data[str(i)][mode] is not None
                            for mode in WATERMARK_METHODS.keys()
                        ]
                    )
                )
                for i in range(LIMIT)
            ]
        )
    elif result_type == "metric":
        return sum(
            [
                (
                    all(
                        [
                            data[str(i)][mode] is not None
                            for mode in QUALITY_METRICS.keys()
                        ]
                    )
                )
                for i in range(SUBSET_LIMIT)
            ]
        ) * (LIMIT // SUBSET_LIMIT)


def get_example_from_json(path):
    data = load_json(path)
    return [
        decode_image_from_string(data[str(i)]["thumbnail"]) for i in [0, 1, 10, 100]
    ]


def get_distances_from_json(path, mode):
    try:
        data = load_json(path)
        messages = [decode_array_from_string(data[str(i)][mode]) for i in range(LIMIT)]
        return [
            message_distance(message, GROUND_TRUTH_MESSAGES[mode])
            for message in messages
        ]
    except TypeError:
        return None


def get_metrics_from_json(path, mode):
    try:
        data = load_json(path)
        metrics = [data[str(i)][mode] for i in range(SUBSET_LIMIT)]
        if len(metrics) < SUBSET_LIMIT or any([metric is None for metric in metrics]):
            return None
        return metrics
    except KeyError:
        return None

--------------------------------------------------------------------------------
/distortions/__init__.py:
--------------------------------------------------------------------------------
from .distortions import (
    distortion_strength_paras,
    relative_strength_to_absolute,
    apply_distortion,
)

--------------------------------------------------------------------------------
/guided_diffusion/__init__.py:
--------------------------------------------------------------------------------
from .model_utils import get_default_guided_diffusion_paras, load_guided_diffusion_model
from .generate import guided_ddim_sample, guided_reverse_ddim_sample

--------------------------------------------------------------------------------
/guided_diffusion/dist_util.py:
--------------------------------------------------------------------------------
import io
import os
import socket

import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8

SETUP_RETRY_COUNT = 3


def setup_dist():
    """
    Setup a distributed process group.
    """
    if dist.is_initialized():
        return
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"

    comm = MPI.COMM_WORLD
    backend = "gloo" if not th.cuda.is_available() else "nccl"

    if backend == "gloo":
        hostname = "localhost"
    else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    os.environ["RANK"] = str(comm.rank)
    os.environ["WORLD_SIZE"] = str(comm.size)

    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group(backend=backend, init_method="env://")


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        return th.device(f"cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch file without redundant fetches across MPI ranks.
    """
    chunk_size = 2**30  # MPI has a relatively small size limit
    if MPI.COMM_WORLD.Get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            data = f.read()
        num_chunks = len(data) // chunk_size
        if len(data) % chunk_size:
            num_chunks += 1
        MPI.COMM_WORLD.bcast(num_chunks)
        for i in range(0, len(data), chunk_size):
            MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
    else:
        num_chunks = MPI.COMM_WORLD.bcast(None)
        data = bytes()
        for _ in range(num_chunks):
            data += MPI.COMM_WORLD.bcast(None)

    return th.load(io.BytesIO(data), **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)


def _find_free_port():
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()

--------------------------------------------------------------------------------
/guided_diffusion/generate.py:
--------------------------------------------------------------------------------
from PIL import Image
import torch
from utils import set_random_seed, to_pil, to_tensor
from .script_util import NUM_CLASSES


# Guided diffusion with DDIM sampler
def guided_ddim_sample(
    model,
    diffusion,
    labels,
    image_size,
    diffusion_seed,
    init_latent=None,
    progressive=False,
    return_image=False,
):
    # Diffusion seed is the random seed for diffusion sampling
    set_random_seed(diffusion_seed)
    # For guided diffusion, prompts are class ids
    assert isinstance(labels, list) and all(
        isinstance(label, int) and 0 <= label < NUM_CLASSES for label in labels
    )
    # Device and shape
    device = next(model.parameters()).device
    shape = (len(labels), 3, image_size, image_size)
    # The random initial latent is determined by the diffusion seed, so no need to keep it
    if init_latent is None:
        init_latent = torch.randn(*shape, device=device)
    # Diffusion
    if not progressive:
        output = diffusion.ddim_sample_loop(
            model=model,
            shape=shape,
            noise=init_latent,
            model_kwargs=dict(y=torch.tensor(labels, device=device)),
            device=device,
            return_image=return_image,
        )
        return output
    else:
        output = []
        for sample in diffusion.ddim_sample_loop_progressive(
            model=model,
            shape=shape,
            noise=init_latent,
            model_kwargs=dict(y=torch.tensor(labels, device=device)),
            device=device,
        ):
            if not return_image:
                output.append(sample["sample"])
            else:
                output.append(to_pil(sample["sample"]))
        return output


# Reverse guided diffusion with DDIM sampler
def guided_reverse_ddim_sample(
    model,
    diffusion,
    images,
    image_size,
    default_labels=0,
    progressive=False,
    return_image=False,
):
    # Reverse diffusion of DDIM sampling is deterministic, so this line has no effect
    set_random_seed(0)
    # Device and shape
    device = next(model.parameters()).device
    shape = (len(images), 3, image_size, image_size)
    # If default labels is a single int, repeat it for all images
    if isinstance(default_labels, int):
        default_labels = [default_labels] * len(images)
    # Check whether the inputs are PIL images
    if isinstance(images[0], Image.Image):
        images = to_tensor(images, norm_type="naive").to(device)
    # Reversed diffusion
    if not progressive:
        output = diffusion.ddim_reverse_sample_loop(
            model=model,
            shape=shape,
            image=images,
            # Reverse diffusion does not depend on the labels, thus pass in dummy labels
            model_kwargs=dict(y=torch.tensor(default_labels, device=device)),
            device=device,
        )
        if not return_image:
            return output
        else:
            return to_pil(output)
    else:
        output = []
        for sample in diffusion.ddim_reverse_sample_loop_progressive(
            model=model,
            shape=shape,
            image=images,
            # Reverse diffusion does not depend on the labels, thus pass in dummy labels
            model_kwargs=dict(y=torch.tensor(default_labels, device=device)),
            device=device,
        ):
            if not return_image:
                output.append(sample["sample"])
            else:
                output.append(to_pil(sample["sample"]))
        return output

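For orientation, a minimal end-to-end sketch of driving the two samplers above. It assumes the pretrained guided-diffusion checkpoints are available under MODEL_DIR (see load_guided_diffusion_model in model_utils.py), a CUDA device, and that the repository root is importable; the class ids are arbitrary ImageNet labels:

import torch
from guided_diffusion import (
    load_guided_diffusion_model,
    guided_ddim_sample,
    guided_reverse_ddim_sample,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, diffusion = load_guided_diffusion_model(image_size=256, device=device)

# Class-conditional DDIM sampling; labels are ImageNet class ids, one per image.
images = guided_ddim_sample(
    model,
    diffusion,
    labels=[207, 948],
    image_size=256,
    diffusion_seed=0,
    return_image=True,  # return PIL images instead of tensors
)

# Deterministic DDIM inversion of the samples back to initial latents.
latents = guided_reverse_ddim_sample(model, diffusion, images, image_size=256)
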
--------------------------------------------------------------------------------
/guided_diffusion/losses.py:
--------------------------------------------------------------------------------
import numpy as np
import torch as th


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, th.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )


def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal.
    """
    return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
53 | """ 54 | assert x.shape == means.shape == log_scales.shape 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 58 | cdf_plus = approx_standard_normal_cdf(plus_in) 59 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 60 | cdf_min = approx_standard_normal_cdf(min_in) 61 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 62 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 63 | cdf_delta = cdf_plus - cdf_min 64 | log_probs = th.where( 65 | x < -0.999, 66 | log_cdf_plus, 67 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 68 | ) 69 | assert log_probs.shape == x.shape 70 | return log_probs 71 | -------------------------------------------------------------------------------- /guided_diffusion/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from .script_util import ( 4 | model_and_diffusion_defaults, 5 | create_model_and_diffusion, 6 | ) 7 | 8 | 9 | # Diffusion model parameters 10 | guided_diffusion_64x64_paras = dict( 11 | image_size=64, 12 | num_channels=192, 13 | num_res_blocks=3, 14 | num_head_channels=64, 15 | attention_resolutions="32,16,8", 16 | dropout=0.1, 17 | class_cond=True, 18 | resblock_updown=True, 19 | use_fp16=True, 20 | use_new_attention_order=True, 21 | learn_sigma=True, 22 | diffusion_steps=1000, 23 | noise_schedule="cosine", 24 | timestep_respacing="ddim50", 25 | use_scale_shift_norm=True, 26 | ) 27 | guided_diffusion_256x256_paras = dict( 28 | image_size=256, 29 | num_channels=256, 30 | num_res_blocks=2, 31 | num_head_channels=64, 32 | attention_resolutions="32,16,8", 33 | class_cond=True, 34 | resblock_updown=True, 35 | use_fp16=True, 36 | use_new_attention_order=False, 37 | learn_sigma=True, 38 | diffusion_steps=1000, 39 | noise_schedule="linear", 40 | timestep_respacing="ddim50", 41 | use_scale_shift_norm=True, 42 | ) 43 | 44 | 45 | # Get the default parameters for guided diffusion 46 | def get_default_guided_diffusion_paras(image_size): 47 | # Support two image sizes 48 | assert image_size in [64, 256] 49 | if image_size == 64: 50 | return guided_diffusion_64x64_paras 51 | else: 52 | return guided_diffusion_256x256_paras 53 | 54 | 55 | # Load guided diffusion model and weights 56 | def load_guided_diffusion_model(image_size, device): 57 | # Support two image sizes 58 | assert image_size in [64, 256] 59 | paras = model_and_diffusion_defaults() 60 | # Update with default parameters, see https://github.com/openai/guided-diffusion 61 | paras.update(get_default_guided_diffusion_paras(image_size)) 62 | # Initilaize model and load weights 63 | model, diffusion = create_model_and_diffusion(**paras) 64 | model.load_state_dict( 65 | torch.load( 66 | os.path.join( 67 | os.environ.get("MODEL_DIR"), 68 | f"guided-diffusion/{image_size}x{image_size}_diffusion.pt", 69 | ), 70 | map_location=device, 71 | ) 72 | ) 73 | model.to(device) 74 | # Convert to FP16 75 | if paras["use_fp16"]: 76 | model.convert_to_fp16() 77 | # Set eval flag 78 | model.eval() 79 | return model, diffusion 80 | -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/base.py: 
-------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | """ 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | """ 9 | 10 | def __init__(self, num_records=0, valid_ids=None, size=256): 11 | super().__init__() 12 | self.num_records = num_records 13 | self.valid_ids = valid_ids 14 | self.sample_ids = valid_ids 15 | self.size = size 16 | 17 | print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.") 18 | 19 | def __len__(self): 20 | return self.num_records 21 | 22 | @abstractmethod 23 | def __iter__(self): 24 | pass 25 | -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__( 11 | self, txt_file, data_root, size=None, interpolation="bicubic", flip_p=0.5 12 | ): 13 | self.data_paths = txt_file 14 | self.data_root = data_root 15 | with open(self.data_paths, "r") as f: 16 | self.image_paths = f.read().splitlines() 17 | self._length = len(self.image_paths) 18 | self.labels = { 19 | "relative_file_path_": [l for l in self.image_paths], 20 | "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], 21 | } 22 | 23 | self.size = size 24 | self.interpolation = { 25 | "linear": PIL.Image.LINEAR, 26 | "bilinear": PIL.Image.BILINEAR, 27 | "bicubic": PIL.Image.BICUBIC, 28 | "lanczos": PIL.Image.LANCZOS, 29 | }[interpolation] 30 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 31 | 32 | def __len__(self): 33 | return self._length 34 | 35 | def __getitem__(self, i): 36 | example = dict((k, self.labels[k][i]) for k in self.labels) 37 | image = Image.open(example["file_path_"]) 38 | if not image.mode == "RGB": 39 | image = image.convert("RGB") 40 | 41 | # default to score-sde preprocessing 42 | img = np.array(image).astype(np.uint8) 43 | crop = min(img.shape[0], img.shape[1]) 44 | ( 45 | h, 46 | w, 47 | ) = ( 48 | img.shape[0], 49 | img.shape[1], 50 | ) 51 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 52 | 53 | image = Image.fromarray(img) 54 | if self.size is not None: 55 | image = image.resize((self.size, self.size), resample=self.interpolation) 56 | 57 | image = self.flip(image) 58 | image = np.array(image).astype(np.uint8) 59 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 60 | return example 61 | 62 | 63 | class LSUNChurchesTrain(LSUNBase): 64 | def __init__(self, **kwargs): 65 | super().__init__( 66 | txt_file="data/lsun/church_outdoor_train.txt", 67 | data_root="data/lsun/churches", 68 | **kwargs 69 | ) 70 | 71 | 72 | class LSUNChurchesValidation(LSUNBase): 73 | def __init__(self, flip_p=0.0, **kwargs): 74 | super().__init__( 75 | txt_file="data/lsun/church_outdoor_val.txt", 76 | data_root="data/lsun/churches", 77 | flip_p=flip_p, 78 | **kwargs 79 | ) 80 | 81 | 82 | class LSUNBedroomsTrain(LSUNBase): 83 | def __init__(self, **kwargs): 84 | super().__init__( 85 | txt_file="data/lsun/bedrooms_train.txt", 86 | data_root="data/lsun/bedrooms", 87 | **kwargs 88 | ) 89 | 90 | 91 | class LSUNBedroomsValidation(LSUNBase): 92 | def 
__init__(self, flip_p=0.0, **kwargs): 93 | super().__init__( 94 | txt_file="data/lsun/bedrooms_val.txt", 95 | data_root="data/lsun/bedrooms", 96 | flip_p=flip_p, 97 | **kwargs 98 | ) 99 | 100 | 101 | class LSUNCatsTrain(LSUNBase): 102 | def __init__(self, **kwargs): 103 | super().__init__( 104 | txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs 105 | ) 106 | 107 | 108 | class LSUNCatsValidation(LSUNBase): 109 | def __init__(self, flip_p=0.0, **kwargs): 110 | super().__init__( 111 | txt_file="data/lsun/cat_val.txt", 112 | data_root="data/lsun/cats", 113 | flip_p=flip_p, 114 | **kwargs 115 | ) 116 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | ( 16 | torch.tensor(0, dtype=torch.int) 17 | if use_num_upates 18 | else torch.tensor(-1, dtype=torch.int) 19 | ), 20 | ) 21 | 22 | for name, p in model.named_parameters(): 23 | if p.requires_grad: 24 | # remove as '.'-character is not allowed in buffers 25 | s_name = name.replace(".", "") 26 | self.m_name2s_name.update({name: s_name}) 27 | self.register_buffer(s_name, p.clone().detach().data) 28 | 29 | self.collected_params = [] 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | shadow_params[sname].sub_( 49 | one_minus_decay * (shadow_params[sname] - m_param[key]) 50 | ) 51 | else: 52 | assert not key in self.m_name2s_name 53 | 54 | def copy_to(self, model): 55 | m_param = dict(model.named_parameters()) 56 | shadow_params = dict(self.named_buffers()) 57 | for key in m_param: 58 | if m_param[key].requires_grad: 59 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 60 | else: 61 | assert not key in self.m_name2s_name 62 | 63 | def store(self, parameters): 64 | self.collected_params = [param.clone() for param in parameters] 65 | 66 | def restore(self, parameters): 67 | for c_param, param in zip(self.collected_params, parameters): 68 | param.data.copy_(c_param.data) 69 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator 2 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributional import compute_fid 2 | from .image import ( 3 | compute_mse, 4 | compute_psnr, 5 | compute_ssim, 6 | compute_nmi, 7 | compute_mse_repeated, 8 | compute_psnr_repeated, 9 | compute_ssim_repeated, 10 | compute_nmi_repeated, 11 | compute_image_distance_repeated, 12 | ) 13 | from .perceptual import ( 14 | load_perceptual_models, 15 | compute_lpips, 16 | compute_watson, 17 | compute_lpips_repeated, 18 | compute_watson_repeated, 19 | compute_perceptual_metric_repeated, 20 | ) 21 | from .aesthetics import ( 22 | load_aesthetics_and_artifacts_models, 23 | compute_aesthetics_and_artifacts_scores, 24 | ) 25 | from .clip import load_open_clip_model_preprocess_and_tokenizer, compute_clip_score 26 | from .prompt import ( 27 | load_perplexity_model_and_tokenizer, 28 | compute_prompt_perplexity, 29 | ) 30 | -------------------------------------------------------------------------------- /metrics/aesthetics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from transformers import CLIPModel, CLIPProcessor 4 | from .aesthetics_scorer import preprocess, load_model 5 | 6 | 7 | def load_aesthetics_and_artifacts_models(device=torch.device("cuda")): 8 | model = CLIPModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") 9 | vision_model = model.vision_model 10 | vision_model.to(device) 11 | del model 12 | clip_processor = CLIPProcessor.from_pretrained( 13 | "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 14 | ) 15 | rating_model = load_model("aesthetics_scorer_rating_openclip_vit_h_14").to(device) 16 | artifacts_model = load_model("aesthetics_scorer_artifacts_openclip_vit_h_14").to( 17 | device 18 | ) 19 | return vision_model, clip_processor, rating_model, artifacts_model 20 | 21 | 22 | def compute_aesthetics_and_artifacts_scores( 23 | images, models, device=torch.device("cuda") 24 | ): 25 | vision_model, clip_processor, rating_model, artifacts_model = models 26 | 27 | inputs = clip_processor(images=images, return_tensors="pt").to(device) 28 | with torch.no_grad(): 29 | vision_output = vision_model(**inputs) 30 | pooled_output = vision_output.pooler_output 31 | embedding = preprocess(pooled_output) 32 | with torch.no_grad(): 33 | rating = rating_model(embedding) 34 | artifact = artifacts_model(embedding) 35 | return ( 36 | rating.detach().cpu().numpy().flatten().tolist(), 37 | artifact.detach().cpu().numpy().flatten().tolist(), 38 | ) 39 | -------------------------------------------------------------------------------- /metrics/aesthetics_scorer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/kenjiqq/aesthetics-scorer#validation-split-of-diffusiondb-dataset 3 | """ 4 | from .model import preprocess, load_model 5 | -------------------------------------------------------------------------------- /metrics/aesthetics_scorer/model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import json 4 | import os 5 | import inspect 6 | 7 | 8 | class AestheticScorer(nn.Module): 9 | def __init__( 10 | self, 11 | input_size=0, 12 | use_activation=False, 13 | dropout=0.2, 14 | config=None, 15 | hidden_dim=1024, 16 | reduce_dims=False, 17 | output_activation=None, 18 | ): 19 | super().__init__() 20 | self.config = { 21 | "input_size": input_size, 22 | "use_activation": use_activation, 23 | "dropout": dropout, 24 | "hidden_dim": hidden_dim, 25 | "reduce_dims": reduce_dims, 26 | "output_activation": output_activation, 27 | } 28 | if config != None: 29 | self.config.update(config) 30 | 31 | layers = [ 32 | nn.Linear(self.config["input_size"], self.config["hidden_dim"]), 33 | nn.ReLU() if self.config["use_activation"] else None, 34 | nn.Dropout(self.config["dropout"]), 35 | nn.Linear( 36 | self.config["hidden_dim"], 37 | round(self.config["hidden_dim"] / (2 if reduce_dims else 1)), 38 | ), 39 | nn.ReLU() if self.config["use_activation"] else None, 40 | nn.Dropout(self.config["dropout"]), 41 | nn.Linear( 42 | round(self.config["hidden_dim"] / (2 if reduce_dims else 1)), 43 | round(self.config["hidden_dim"] / (4 if reduce_dims else 1)), 44 | ), 45 | nn.ReLU() if self.config["use_activation"] else None, 46 | nn.Dropout(self.config["dropout"]), 47 | nn.Linear( 48 | round(self.config["hidden_dim"] / (4 if reduce_dims else 1)), 49 | round(self.config["hidden_dim"] / (8 if reduce_dims else 1)), 50 | ), 51 | nn.ReLU() if self.config["use_activation"] else None, 52 | nn.Linear(round(self.config["hidden_dim"] / (8 if reduce_dims else 1)), 1), 53 | ] 54 | if self.config["output_activation"] == "sigmoid": 55 | layers.append(nn.Sigmoid()) 56 | layers = [x for x in layers if x is not None] 57 | self.layers = nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | if self.config["output_activation"] == "sigmoid": 61 | upper, lower = 10, 1 62 | scale = upper - lower 63 | return (self.layers(x) * scale) + lower 64 | else: 65 | return self.layers(x) 66 | 67 | def save(self, save_name): 68 | split_name = os.path.splitext(save_name) 69 | with open(f"{split_name[0]}.config", "w") as outfile: 70 | outfile.write(json.dumps(self.config, indent=4)) 71 | 72 | for i in range( 73 | 6 74 | ): # saving sometiles fails, so retry 5 times, might be windows issue 75 | try: 76 | torch.save(self.state_dict(), save_name) 77 | break 78 | except RuntimeError as e: 79 | # check if error contains string "File" 80 | if "cannot be opened" in str(e) and i < 5: 81 | print("Model save failed, retrying...") 82 | else: 83 | raise e 84 | 85 | 86 | def preprocess(embeddings): 87 | return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True) 88 | 89 | 90 | def load_model(weight_name, device="cuda" if torch.cuda.is_available() else "cpu"): 91 | weight_folder = os.path.abspath( 92 | os.path.join( 93 | inspect.getfile(load_model), 94 | "../weights", 95 | ) 96 | ) 97 | weight_path = os.path.join(weight_folder, f"{weight_name}.pth") 98 | config_path = os.path.join(weight_folder, f"{weight_name}.config") 99 | with open(config_path, "r") as config_file: 100 | config = json.load(config_file) 101 | model = AestheticScorer(config=config) 102 | model.load_state_dict(torch.load(weight_path, map_location=device)) 103 | model.eval() 104 | return model 105 | -------------------------------------------------------------------------------- /metrics/clean_fid/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/GaParmar/clean-fid/tree/main 3 | """ 4 | -------------------------------------------------------------------------------- /metrics/clean_fid/clip_features.py: -------------------------------------------------------------------------------- 1 | # pip install git+https://github.com/openai/CLIP.git 2 | import pdb 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | import clip 8 | from .fid import compute_fid 9 | 10 | 11 | def img_preprocess_clip(img_np): 12 | x = Image.fromarray(img_np.astype(np.uint8)).convert("RGB") 13 | T = transforms.Compose( 14 | [ 15 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), 16 | transforms.CenterCrop(224), 17 | ] 18 | ) 19 | return np.asarray(T(x)).clip(0, 255).astype(np.uint8) 20 | 21 | 22 | class CLIP_fx: 23 | def __init__(self, name="ViT-B/32", device="cuda"): 24 | self.model, _ = clip.load(name, device=device) 25 | self.model.eval() 26 | self.name = "clip_" + name.lower().replace("-", "_").replace("/", "_") 27 | 28 | def __call__(self, img_t): 29 | img_x = img_t / 255.0 30 | T_norm = transforms.Normalize( 31 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 32 | ) 33 | img_x = T_norm(img_x) 34 | assert torch.is_tensor(img_x) 35 | if len(img_x.shape) == 3: 36 | img_x = img_x.unsqueeze(0) 37 | B, C, H, W = img_x.shape 38 | with torch.no_grad(): 39 | z = self.model.encode_image(img_x) 40 | return z 41 | -------------------------------------------------------------------------------- /metrics/clean_fid/downloads_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import requests 4 | import shutil 5 | 6 | 7 | inception_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" 8 | 9 | 10 | """ 11 | Download the pretrined inception weights if it does not exists 12 | ARGS: 13 | fpath - output folder path 14 | """ 15 | 16 | 17 | def check_download_inception(fpath="./"): 18 | inception_path = os.path.join(fpath, "inception-2015-12-05.pt") 19 | if not os.path.exists(inception_path): 20 | # download the file 21 | with urllib.request.urlopen(inception_url) as response, open( 22 | inception_path, "wb" 23 | ) as f: 24 | shutil.copyfileobj(response, f) 25 | return inception_path 26 | 27 | 28 | """ 29 | Download any url if it does not exist 30 | ARGS: 31 | local_folder - output folder path 32 | url - the weburl to download 33 | """ 34 | 35 | 36 | def check_download_url(local_folder, url): 37 | name = os.path.basename(url) 38 | local_path = os.path.join(local_folder, name) 39 | if not os.path.exists(local_path): 40 | os.makedirs(local_folder, exist_ok=True) 41 | print(f"downloading statistics to {local_path}") 42 | with urllib.request.urlopen(url) as response, open(local_path, "wb") as f: 43 | shutil.copyfileobj(response, f) 44 | return local_path 45 | 46 | 47 | """ 48 | Download a file from google drive 49 | ARGS: 50 | file_id - id of the google drive file 51 | out_path - output folder path 52 | """ 53 | 54 | 55 | def download_google_drive(file_id, out_path): 56 | def get_confirm_token(response): 57 | for key, value in response.cookies.items(): 58 | if key.startswith("download_warning"): 59 | return value 60 | return None 61 | 62 | URL = "https://drive.google.com/uc?export=download" 63 | session = 
requests.Session() 64 | response = session.get(URL, params={"id": file_id}, stream=True) 65 | token = get_confirm_token(response) 66 | 67 | if token: 68 | params = {"id": file_id, "confirm": token} 69 | response = session.get(URL, params=params, stream=True) 70 | 71 | CHUNK_SIZE = 32768 72 | with open(out_path, "wb") as f: 73 | for chunk in response.iter_content(CHUNK_SIZE): 74 | if chunk: 75 | f.write(chunk) 76 | -------------------------------------------------------------------------------- /metrics/clean_fid/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | helpers for extracting features from image 3 | """ 4 | import os 5 | import platform 6 | import numpy as np 7 | import torch 8 | from torch.hub import get_dir 9 | from .downloads_helper import check_download_url 10 | from .inception_pytorch import InceptionV3 11 | from .inception_torchscript import InceptionV3W 12 | 13 | 14 | """ 15 | returns a functions that takes an image in range [0,255] 16 | and outputs a feature embedding vector 17 | """ 18 | 19 | 20 | def feature_extractor( 21 | name="torchscript_inception", 22 | device=torch.device("cuda"), 23 | resize_inside=False, 24 | use_dataparallel=True, 25 | ): 26 | if name == "torchscript_inception": 27 | path = "./" if platform.system() == "Windows" else "/tmp" 28 | model = InceptionV3W(path, download=True, resize_inside=resize_inside).to( 29 | device 30 | ) 31 | model.eval() 32 | if use_dataparallel: 33 | model = torch.nn.DataParallel(model) 34 | 35 | def model_fn(x): 36 | return model(x) 37 | 38 | elif name == "pytorch_inception": 39 | model = InceptionV3(output_blocks=[3], resize_input=False).to(device) 40 | model.eval() 41 | if use_dataparallel: 42 | model = torch.nn.DataParallel(model) 43 | 44 | def model_fn(x): 45 | return model(x / 255)[0].squeeze(-1).squeeze(-1) 46 | 47 | else: 48 | raise ValueError(f"{name} feature extractor not implemented") 49 | return model_fn 50 | 51 | 52 | """ 53 | Build a feature extractor for each of the modes 54 | """ 55 | 56 | 57 | def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True): 58 | if mode == "legacy_pytorch": 59 | feat_model = feature_extractor( 60 | name="pytorch_inception", 61 | resize_inside=False, 62 | device=device, 63 | use_dataparallel=use_dataparallel, 64 | ) 65 | elif mode == "legacy_tensorflow": 66 | feat_model = feature_extractor( 67 | name="torchscript_inception", 68 | resize_inside=True, 69 | device=device, 70 | use_dataparallel=use_dataparallel, 71 | ) 72 | elif mode == "clean": 73 | feat_model = feature_extractor( 74 | name="torchscript_inception", 75 | resize_inside=False, 76 | device=device, 77 | use_dataparallel=use_dataparallel, 78 | ) 79 | return feat_model 80 | 81 | 82 | """ 83 | Load precomputed reference statistics for commonly used datasets 84 | """ 85 | 86 | 87 | def get_reference_statistics( 88 | name, 89 | res, 90 | mode="clean", 91 | model_name="inception_v3", 92 | seed=0, 93 | split="test", 94 | metric="FID", 95 | ): 96 | base_url = "https://www.cs.cmu.edu/~clean-fid/stats/" 97 | if split == "custom": 98 | res = "na" 99 | if model_name == "inception_v3": 100 | model_modifier = "" 101 | else: 102 | model_modifier = "_" + model_name 103 | if metric == "FID": 104 | rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz") 105 | url = f"{base_url}/{rel_path}" 106 | stats_folder = os.path.join(get_dir(), "fid_stats") 107 | fpath = check_download_url(local_folder=stats_folder, url=url) 108 | stats = np.load(fpath) 109 | mu, 
sigma = stats["mu"], stats["sigma"] 110 | return mu, sigma 111 | elif metric == "KID": 112 | rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz") 113 | url = f"{base_url}/{rel_path}" 114 | stats_folder = os.path.join(get_dir(), "fid_stats") 115 | fpath = check_download_url(local_folder=stats_folder, url=url) 116 | stats = np.load(fpath) 117 | return stats["feats"] 118 | -------------------------------------------------------------------------------- /metrics/clean_fid/inception_torchscript.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import contextlib 5 | from .downloads_helper import * 6 | 7 | 8 | @contextlib.contextmanager 9 | def disable_gpu_fuser_on_pt19(): 10 | # On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. See 11 | # https://github.com/GaParmar/clean-fid/issues/5 12 | # https://github.com/pytorch/pytorch/issues/64062 13 | if torch.__version__.startswith("1.9."): 14 | old_val = torch._C._jit_can_fuse_on_gpu() 15 | torch._C._jit_override_can_fuse_on_gpu(False) 16 | yield 17 | if torch.__version__.startswith("1.9."): 18 | torch._C._jit_override_can_fuse_on_gpu(old_val) 19 | 20 | 21 | class InceptionV3W(nn.Module): 22 | """ 23 | Wrapper around Inception V3 torchscript model provided here 24 | https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt 25 | 26 | path: locally saved inception weights 27 | """ 28 | 29 | def __init__(self, path, download=True, resize_inside=False): 30 | super(InceptionV3W, self).__init__() 31 | # download the network if it is not present at the given directory 32 | # use the current directory by default 33 | if download: 34 | check_download_inception(fpath=path) 35 | path = os.path.join(path, "inception-2015-12-05.pt") 36 | self.base = torch.jit.load(path).eval() 37 | self.layers = self.base.layers 38 | self.resize_inside = resize_inside 39 | 40 | """ 41 | Get the inception features without resizing 42 | x: Image with values in range [0,255] 43 | """ 44 | 45 | def forward(self, x): 46 | with disable_gpu_fuser_on_pt19(): 47 | bs = x.shape[0] 48 | if self.resize_inside: 49 | features = self.base(x, return_features=True).view((bs, 2048)) 50 | else: 51 | # make sure it is resized already 52 | assert (x.shape[2] == 299) and (x.shape[3] == 299) 53 | # apply normalization 54 | x1 = x - 128 55 | x2 = x1 / 128 56 | features = self.layers.forward( 57 | x2, 58 | ).view((bs, 2048)) 59 | return features 60 | -------------------------------------------------------------------------------- /metrics/clean_fid/leaderboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import shutil 4 | import urllib.request 5 | 6 | 7 | def get_score( 8 | model_name=None, 9 | dataset_name=None, 10 | dataset_res=None, 11 | dataset_split=None, 12 | task_name=None, 13 | ): 14 | # download the csv file from server 15 | url = "https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv" 16 | local_path = "/tmp/leaderboard.csv" 17 | with urllib.request.urlopen(url) as response, open(local_path, "wb") as f: 18 | shutil.copyfileobj(response, f) 19 | 20 | d_field2idx = {} 21 | l_matches = [] 22 | with open(local_path, "r") as f: 23 | csvreader = csv.reader(f) 24 | l_fields = next(csvreader) 25 | for idx, val in enumerate(l_fields): 26 | d_field2idx[val.strip()] = idx 27 | # iterate through all rows 28 | for row in csvreader: 29 | # skip empty rows 30 | if 
len(row) == 0: 31 | continue 32 | # skip if the filter doesn't match 33 | if model_name is not None and ( 34 | row[d_field2idx["model_name"]].strip() != model_name 35 | ): 36 | continue 37 | if dataset_name is not None and ( 38 | row[d_field2idx["dataset_name"]].strip() != dataset_name 39 | ): 40 | continue 41 | if dataset_res is not None and ( 42 | row[d_field2idx["dataset_res"]].strip() != dataset_res 43 | ): 44 | continue 45 | if dataset_split is not None and ( 46 | row[d_field2idx["dataset_split"]].strip() != dataset_split 47 | ): 48 | continue 49 | if task_name is not None and ( 50 | row[d_field2idx["task_name"]].strip() != task_name 51 | ): 52 | continue 53 | curr = {} 54 | for f in l_fields: 55 | curr[f.strip()] = row[d_field2idx[f.strip()]].strip() 56 | l_matches.append(curr) 57 | os.remove(local_path) 58 | return l_matches 59 | -------------------------------------------------------------------------------- /metrics/clean_fid/resize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for resizing with multiple CPU cores 3 | """ 4 | import os 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | import torch.nn.functional as F 9 | 10 | 11 | def build_resizer(mode): 12 | if mode == "clean": 13 | return make_resizer("PIL", False, "bicubic", (299, 299)) 14 | # if using legacy tensorflow, do not manually resize outside the network 15 | elif mode == "legacy_tensorflow": 16 | return lambda x: x 17 | elif mode == "legacy_pytorch": 18 | return make_resizer("PyTorch", False, "bilinear", (299, 299)) 19 | else: 20 | raise ValueError(f"Invalid mode {mode} specified") 21 | 22 | 23 | """ 24 | Construct a function that resizes a numpy image based on the 25 | flags passed in. 26 | """ 27 | 28 | 29 | def make_resizer(library, quantize_after, filter, output_size): 30 | if library == "PIL" and quantize_after: 31 | name_to_filter = { 32 | "bicubic": Image.BICUBIC, 33 | "bilinear": Image.BILINEAR, 34 | "nearest": Image.NEAREST, 35 | "lanczos": Image.LANCZOS, 36 | "box": Image.BOX, 37 | } 38 | 39 | def func(x): 40 | x = Image.fromarray(x) 41 | x = x.resize(output_size, resample=name_to_filter[filter]) 42 | x = np.asarray(x).clip(0, 255).astype(np.uint8) 43 | return x 44 | 45 | elif library == "PIL" and not quantize_after: 46 | name_to_filter = { 47 | "bicubic": Image.BICUBIC, 48 | "bilinear": Image.BILINEAR, 49 | "nearest": Image.NEAREST, 50 | "lanczos": Image.LANCZOS, 51 | "box": Image.BOX, 52 | } 53 | s1, s2 = output_size 54 | 55 | def resize_single_channel(x_np): 56 | img = Image.fromarray(x_np.astype(np.float32), mode="F") 57 | img = img.resize(output_size, resample=name_to_filter[filter]) 58 | return np.asarray(img).clip(0, 255).reshape(s2, s1, 1) 59 | 60 | def func(x): 61 | x = [resize_single_channel(x[:, :, idx]) for idx in range(3)] 62 | x = np.concatenate(x, axis=2).astype(np.float32) 63 | return x 64 | 65 | elif library == "PyTorch": 66 | import warnings 67 | 68 | # ignore the numpy warnings 69 | warnings.filterwarnings("ignore") 70 | 71 | def func(x): 72 | x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...] 
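# x arrives as an HWC numpy image; the transpose and [None, ...] produce a float NCHW tensor with a batch dim so F.interpolate below can resize it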
73 | x = F.interpolate(x, size=output_size, mode=filter, align_corners=False) 74 | x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255) 75 | if quantize_after: 76 | x = x.astype(np.uint8) 77 | return x 78 | 79 | else: 80 | raise NotImplementedError("library [%s] is not include" % library) 81 | return func 82 | 83 | 84 | class FolderResizer(torch.utils.data.Dataset): 85 | def __init__(self, files, outpath, fn_resize, output_ext=".png"): 86 | self.files = files 87 | self.outpath = outpath 88 | self.output_ext = output_ext 89 | self.fn_resize = fn_resize 90 | 91 | def __len__(self): 92 | return len(self.files) 93 | 94 | def __getitem__(self, i): 95 | path = str(self.files[i]) 96 | img_np = np.asarray(Image.open(path)) 97 | img_resize_np = self.fn_resize(img_np) 98 | # swap the output extension 99 | basename = os.path.basename(path).split(".")[0] + self.output_ext 100 | outname = os.path.join(self.outpath, basename) 101 | if self.output_ext == ".npy": 102 | np.save(outname, img_resize_np) 103 | elif self.output_ext == ".png": 104 | img_resized_pil = Image.fromarray(img_resize_np) 105 | img_resized_pil.save(outname) 106 | else: 107 | raise ValueError("invalid output extension") 108 | return 0 109 | -------------------------------------------------------------------------------- /metrics/clean_fid/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from PIL import Image 5 | import zipfile 6 | from .resize import build_resizer 7 | 8 | 9 | class ResizeDataset(torch.utils.data.Dataset): 10 | """ 11 | A placeholder Dataset that enables parallelizing the resize operation 12 | using multiple CPU cores 13 | 14 | files: list of all files in the folder 15 | fn_resize: function that takes an np_array as input [0,255] 16 | """ 17 | 18 | def __init__(self, files, mode, size=(299, 299), fdir=None): 19 | self.files = files 20 | self.fdir = fdir 21 | self.transforms = torchvision.transforms.ToTensor() 22 | self.size = size 23 | self.fn_resize = build_resizer(mode) 24 | self.custom_image_tranform = lambda x: x 25 | self._zipfile = None 26 | 27 | def _get_zipfile(self): 28 | assert self.fdir is not None and ".zip" in self.fdir 29 | if self._zipfile is None: 30 | self._zipfile = zipfile.ZipFile(self.fdir) 31 | return self._zipfile 32 | 33 | def __len__(self): 34 | return len(self.files) 35 | 36 | def __getitem__(self, i): 37 | path = str(self.files[i]) 38 | if self.fdir is not None and ".zip" in self.fdir: 39 | with self._get_zipfile().open(path, "r") as f: 40 | img_np = np.array(Image.open(f).convert("RGB")) 41 | elif ".npy" in path: 42 | img_np = np.load(path) 43 | else: 44 | img_pil = Image.open(path).convert("RGB") 45 | img_np = np.array(img_pil) 46 | 47 | # apply a custom image transform before resizing the image to 299x299 48 | img_np = self.custom_image_tranform(img_np) 49 | # fn_resize expects a np array and returns a np array 50 | img_resized = self.fn_resize(img_np) 51 | 52 | # ToTensor() converts to [0,1] only if input in uint8 53 | if img_resized.dtype == "uint8": 54 | img_t = self.transforms(np.array(img_resized)) * 255 55 | elif img_resized.dtype == "float32": 56 | img_t = self.transforms(img_resized) 57 | 58 | return img_t 59 | 60 | 61 | EXTENSIONS = { 62 | "bmp", 63 | "jpg", 64 | "jpeg", 65 | "pgm", 66 | "png", 67 | "ppm", 68 | "tif", 69 | "tiff", 70 | "webp", 71 | "npy", 72 | "JPEG", 73 | "JPG", 74 | "PNG", 75 | } 76 | 
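# --- Usage sketch (not part of the original clean-fid file; added for illustration) ---
# ResizeDataset only handles per-image resizing; batching and CPU parallelism
# come from a plain DataLoader. The glob pattern, batch size, and worker count
# below are illustrative assumptions, not values used elsewhere in this repo.
def _demo_resize_folder(folder, mode="clean", batch_size=32, num_workers=0):
    import glob
    import os

    files = sorted(glob.glob(os.path.join(folder, "*.png")))
    dataset = ResizeDataset(files, mode=mode)
    # num_workers > 0 parallelizes the resizing across CPU cores on fork-based
    # platforms; keep 0 where the lambda attribute above cannot be pickled.
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
    )
    # each batch is a float tensor of shape (B, 3, 299, 299) with values in [0, 255]
    return [batch for batch in loader]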
-------------------------------------------------------------------------------- /metrics/clean_fid/wrappers.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from .features import build_feature_extractor, get_reference_statistics 5 | from .fid import get_batch_features, fid_from_feats 6 | from .resize import build_resizer 7 | 8 | 9 | """ 10 | A helper class that allowing adding the images one batch at a time. 11 | """ 12 | 13 | 14 | class CleanFID: 15 | def __init__(self, mode="clean", model_name="inception_v3", device="cuda"): 16 | self.real_features = [] 17 | self.gen_features = [] 18 | self.mode = mode 19 | self.device = device 20 | if model_name == "inception_v3": 21 | self.feat_model = build_feature_extractor(mode, device) 22 | self.fn_resize = build_resizer(mode) 23 | elif model_name == "clip_vit_b_32": 24 | from .clip_features import CLIP_fx, img_preprocess_clip 25 | 26 | clip_fx = CLIP_fx("ViT-B/32") 27 | self.feat_model = clip_fx 28 | self.fn_resize = img_preprocess_clip 29 | 30 | """ 31 | Funtion that takes an image (PIL.Image or np.array or torch.tensor) 32 | and returns the corresponding feature embedding vector. 33 | The image x is expected to be in range [0, 255] 34 | """ 35 | 36 | def compute_features(self, x): 37 | # if x is a PIL Image 38 | if isinstance(x, Image.Image): 39 | x_np = np.array(x) 40 | x_np_resized = self.fn_resize(x_np) 41 | x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0) 42 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 43 | elif isinstance(x, np.ndarray): 44 | x_np_resized = self.fn_resize(x) 45 | x_t = ( 46 | torch.tensor(x_np_resized.transpose((2, 0, 1))) 47 | .unsqueeze(0) 48 | .to(self.device) 49 | ) 50 | # normalization happens inside the self.feat_model, expected image range here is [0,255] 51 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 52 | elif isinstance(x, torch.Tensor): 53 | # pdb.set_trace() 54 | # add the batch dimension if x is passed in as C,H,W 55 | if len(x.shape) == 3: 56 | x = x.unsqueeze(0) 57 | b, c, h, w = x.shape 58 | # convert back to np array and resize 59 | l_x_np_resized = [] 60 | for _ in range(b): 61 | x_np = x[_].cpu().numpy().transpose((1, 2, 0)) 62 | l_x_np_resized.append(self.fn_resize(x_np)[None,]) 63 | x_np_resized = np.concatenate(l_x_np_resized) 64 | x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device) 65 | # normalization happens inside the self.feat_model, expected image range here is [0,255] 66 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 67 | else: 68 | raise ValueError("image type could not be inferred") 69 | return x_feat 70 | 71 | """ 72 | Extract the faetures from x and add to the list of reference real images 73 | """ 74 | 75 | def add_real_images(self, x): 76 | x_feat = self.compute_features(x) 77 | self.real_features.append(x_feat) 78 | 79 | """ 80 | Extract the faetures from x and add to the list of generated images 81 | """ 82 | 83 | def add_gen_images(self, x): 84 | x_feat = self.compute_features(x) 85 | self.gen_features.append(x_feat) 86 | 87 | """ 88 | Compute FID between the real and generated images added so far 89 | """ 90 | 91 | def calculate_fid(self, verbose=True): 92 | feats1 = np.concatenate(self.real_features) 93 | feats2 = np.concatenate(self.gen_features) 94 | if verbose: 95 | print(f"# real images = {feats1.shape[0]}") 96 | print(f"# generated images = {feats2.shape[0]}") 97 | return 
fid_from_feats(feats1, feats2) 98 | 99 | """ 100 | Remove the real image features added so far 101 | """ 102 | 103 | def reset_real_features(self): 104 | self.real_features = [] 105 | 106 | """ 107 | Remove the generated image features added so far 108 | """ 109 | 110 | def reset_gen_features(self): 111 | self.gen_features = [] 112 | -------------------------------------------------------------------------------- /metrics/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import open_clip 4 | 5 | 6 | def load_open_clip_model_preprocess_and_tokenizer(device=torch.device("cuda")): 7 | clip_model, _, clip_preprocess = open_clip.create_model_and_transforms( 8 | "ViT-g-14", pretrained="laion2b_s12b_b42k", device=device 9 | ) 10 | clip_tokenizer = open_clip.get_tokenizer("ViT-g-14") 11 | return clip_model, clip_preprocess, clip_tokenizer 12 | 13 | 14 | def compute_clip_score( 15 | images, 16 | prompts, 17 | models, 18 | device=torch.device("cuda"), 19 | ): 20 | clip_model, clip_preprocess, clip_tokenizer = models 21 | with torch.no_grad(): 22 | tensors = [clip_preprocess(image) for image in images] 23 | image_processed_tensor = torch.stack(tensors, 0).to(device) 24 | image_features = clip_model.encode_image(image_processed_tensor) 25 | 26 | encoding = clip_tokenizer(prompts).to(device) 27 | text_features = clip_model.encode_text(encoding) 28 | 29 | image_features /= image_features.norm(dim=-1, keepdim=True) 30 | text_features /= text_features.norm(dim=-1, keepdim=True) 31 | 32 | return (image_features @ text_features.T).mean(-1).cpu().numpy().tolist() 33 | -------------------------------------------------------------------------------- /metrics/distributional.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import torch 4 | from PIL import Image 5 | from tqdm.auto import tqdm 6 | from concurrent.futures import ProcessPoolExecutor 7 | from functools import partial 8 | from PIL import Image 9 | from .clean_fid import fid 10 | 11 | 12 | def save_single_image_to_temp(i, image, temp_dir): 13 | save_path = os.path.join(temp_dir, f"{i}.png") 14 | image.save(save_path, "PNG") 15 | 16 | 17 | def save_images_to_temp(images, num_workers, verbose=False): 18 | assert isinstance(images, list) and isinstance(images[0], Image.Image) 19 | temp_dir = tempfile.mkdtemp() 20 | 21 | # Using ProcessPoolExecutor to save images in parallel 22 | func = partial(save_single_image_to_temp, temp_dir=temp_dir) 23 | with ProcessPoolExecutor(max_workers=num_workers) as executor: 24 | tasks = executor.map(func, range(len(images)), images) 25 | list(tasks) if not verbose else list( 26 | tqdm( 27 | tasks, 28 | total=len(images), 29 | desc="Saving images ", 30 | ) 31 | ) 32 | return temp_dir 33 | 34 | 35 | # Compute FID between two sets of images 36 | def compute_fid( 37 | images1, 38 | images2, 39 | mode="legacy", 40 | device=None, 41 | batch_size=64, 42 | num_workers=None, 43 | verbose=False, 44 | ): 45 | # Support four types of FID scores 46 | assert mode in ["legacy", "clean", "clip"] 47 | if mode == "legacy": 48 | mode = "legacy_pytorch" 49 | model_name = "inception_v3" 50 | elif mode == "clean": 51 | mode = "clean" 52 | model_name = "inception_v3" 53 | elif mode == "clip": 54 | mode = "clean" 55 | model_name = "clip_vit_b_32" 56 | else: 57 | assert False 58 | 59 | # Set up device and num_workers 60 | if device is None: 61 | device = ( 62 | torch.device("cuda") if 
torch.cuda.is_available() else torch.device("cpu") 63 | ) 64 | if num_workers is not None: 65 | assert 1 <= num_workers <= os.cpu_count() 66 | else: 67 | num_workers = max(torch.cuda.device_count() * 4, 8) 68 | 69 | # Check images, can be paths or lists of PIL images 70 | if not isinstance(images1, list): 71 | assert isinstance(images1, str) and os.path.exists(images1) 72 | assert isinstance(images2, str) and os.path.exists(images2) 73 | path1 = images1 74 | path2 = images2 75 | else: 76 | assert isinstance(images1, list) and isinstance(images1[0], Image.Image) 77 | assert isinstance(images2, list) and isinstance(images2[0], Image.Image) 78 | # Save images to temp dir if needed 79 | path1 = save_images_to_temp(images1, num_workers=num_workers, verbose=verbose) 80 | path2 = save_images_to_temp(images2, num_workers=num_workers, verbose=verbose) 81 | 82 | # Attempt to cache statistics for path1 83 | if not fid.test_stats_exists(name=str(os.path.abspath(path1)).replace("/", "_"), mode=mode, model_name=model_name): 84 | fid.make_custom_stats( 85 | name=str(os.path.abspath(path1)).replace("/", "_"), 86 | fdir=path1, 87 | mode=mode, 88 | model_name=model_name, 89 | device=device, 90 | num_workers=num_workers, 91 | verbose=verbose, 92 | ) 93 | fid_score = fid.compute_fid( 94 | path2, 95 | dataset_name=str(os.path.abspath(path1)).replace("/", "_"), 96 | dataset_split="custom", 97 | mode=mode, 98 | model_name=model_name, 99 | device=device, 100 | batch_size=batch_size, 101 | num_workers=num_workers, 102 | verbose=verbose, 103 | ) 104 | return fid_score 105 | -------------------------------------------------------------------------------- /metrics/image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from skimage.metrics import ( 6 | mean_squared_error, 7 | peak_signal_noise_ratio, 8 | structural_similarity as structural_similarity_index_measure, 9 | normalized_mutual_information, 10 | ) 11 | from tqdm.auto import tqdm 12 | from concurrent.futures import ThreadPoolExecutor 13 | 14 | 15 | # Process images to numpy arrays 16 | def convert_image_pair_to_numpy(image1, image2): 17 | assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image) 18 | 19 | image1_np = np.array(image1) 20 | image2_np = np.array(image2) 21 | assert image1_np.shape == image2_np.shape 22 | 23 | return image1_np, image2_np 24 | 25 | 26 | # Compute MSE between two images 27 | def compute_mse(image1, image2): 28 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 29 | return float(mean_squared_error(image1_np, image2_np)) 30 | 31 | 32 | # Compute PSNR between two images 33 | def compute_psnr(image1, image2): 34 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 35 | return float(peak_signal_noise_ratio(image1_np, image2_np)) 36 | 37 | 38 | # Compute SSIM between two images 39 | def compute_ssim(image1, image2): 40 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 41 | return float( 42 | structural_similarity_index_measure(image1_np, image2_np, channel_axis=2) 43 | ) 44 | 45 | 46 | # Compute NMI between two images 47 | def compute_nmi(image1, image2): 48 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 49 | return float(normalized_mutual_information(image1_np, image2_np)) 50 | 51 | 52 | # Compute metrics 53 | def compute_metric_repeated( 54 | images1, images2, metric_func, num_workers=None, verbose=False 55 | ): 56 | # Accept 
list of PIL images 57 | assert isinstance(images1, list) and isinstance(images1[0], Image.Image) 58 | assert isinstance(images2, list) and isinstance(images2[0], Image.Image) 59 | assert len(images1) == len(images2) 60 | 61 | if num_workers is not None: 62 | assert 1 <= num_workers <= os.cpu_count() 63 | else: 64 | num_workers = max(torch.cuda.device_count() * 4, 8) 65 | 66 | metric_name = metric_func.__name__.split("_")[1].upper() 67 | 68 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 69 | tasks = executor.map(metric_func, images1, images2) 70 | values = ( 71 | list(tasks) 72 | if not verbose 73 | else list( 74 | tqdm( 75 | tasks, 76 | total=len(images1), 77 | desc=f"{metric_name} ", 78 | ) 79 | ) 80 | ) 81 | return values 82 | 83 | 84 | # Compute MSE between pairs of images 85 | def compute_mse_repeated(images1, images2, num_workers=None, verbose=False): 86 | return compute_metric_repeated(images1, images2, compute_mse, num_workers, verbose) 87 | 88 | 89 | # Compute PSNR between pairs of images 90 | def compute_psnr_repeated(images1, images2, num_workers=None, verbose=False): 91 | return compute_metric_repeated(images1, images2, compute_psnr, num_workers, verbose) 92 | 93 | 94 | # Compute SSIM between pairs of images 95 | def compute_ssim_repeated(images1, images2, num_workers=None, verbose=False): 96 | return compute_metric_repeated(images1, images2, compute_ssim, num_workers, verbose) 97 | 98 | 99 | # Compute NMI between pairs of images 100 | def compute_nmi_repeated(images1, images2, num_workers=None, verbose=False): 101 | return compute_metric_repeated(images1, images2, compute_nmi, num_workers, verbose) 102 | 103 | 104 | def compute_image_distance_repeated( 105 | images1, images2, metric_name, num_workers=None, verbose=False 106 | ): 107 | metric_func = { 108 | "psnr": compute_psnr, 109 | "ssim": compute_ssim, 110 | "nmi": compute_nmi, 111 | }[metric_name] 112 | return compute_metric_repeated(images1, images2, metric_func, num_workers, verbose) 113 | -------------------------------------------------------------------------------- /metrics/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/richzhang/PerceptualSimilarity 3 | """ 4 | from .lpips import LPIPS 5 | -------------------------------------------------------------------------------- /metrics/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def normalize_tensor(in_feat, eps=1e-10): 9 | norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) 10 | return in_feat / (norm_factor + eps) 11 | 12 | 13 | def l2(p0, p1, range=255.0): 14 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 15 | 16 | 17 | def psnr(p0, p1, peak=255.0): 18 | return 10 * np.log10(peak**2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 19 | 20 | 21 | def dssim(p0, p1, range=255.0): 22 | from skimage.measure import compare_ssim 23 | 24 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 25 | 26 | 27 | def tensor2np(tensor_obj): 28 | # change dimension of a tensor object into a numpy array 29 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 30 | 31 | 32 | def np2tensor(np_obj): 33 | # change dimenion of np array into tensor array 34 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 
2, 0, 1))) 35 | 36 | 37 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 38 | # image tensor to lab tensor 39 | from skimage import color 40 | 41 | img = tensor2im(image_tensor) 42 | img_lab = color.rgb2lab(img) 43 | if mc_only: 44 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 45 | if to_norm and not mc_only: 46 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 47 | img_lab = img_lab / 100.0 48 | 49 | return np2tensor(img_lab) 50 | 51 | 52 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 53 | from skimage import color 54 | import warnings 55 | 56 | warnings.filterwarnings("ignore") 57 | 58 | lab = tensor2np(lab_tensor) * 100.0 59 | lab[:, :, 0] = lab[:, :, 0] + 50 60 | 61 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 62 | if return_inbnd: 63 | # convert back to lab, see if we match 64 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 65 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 66 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 67 | return (im2tensor(rgb_back), mask) 68 | else: 69 | return im2tensor(rgb_back) 70 | 71 | 72 | def load_image(path): 73 | if ( 74 | path[-3:] == "bmp" 75 | or path[-3:] == "jpg" 76 | or path[-3:] == "png" 77 | or path[-4:] == "jpeg" 78 | ): 79 | import cv2 80 | 81 | return cv2.imread(path)[:, :, ::-1] 82 | else: 83 | import matplotlib.pyplot as plt 84 | 85 | img = (255 * plt.imread(path)[:, :, :3]).astype("uint8") 86 | 87 | return img 88 | 89 | 90 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 91 | image_numpy = image_tensor[0].cpu().float().numpy() 92 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 93 | return image_numpy.astype(imtype) 94 | 95 | 96 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 97 | return torch.Tensor( 98 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 99 | ) 100 | 101 | 102 | def tensor2vec(vector_tensor): 103 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 104 | 105 | 106 | def voc_ap(rec, prec, use_07_metric=False): 107 | """ap = voc_ap(rec, prec, [use_07_metric]) 108 | Compute VOC AP given precision and recall. 109 | If use_07_metric is true, uses the 110 | VOC 07 11 point method (default:False). 
111 | """ 112 | if use_07_metric: 113 | # 11 point metric 114 | ap = 0.0 115 | for t in np.arange(0.0, 1.1, 0.1): 116 | if np.sum(rec >= t) == 0: 117 | p = 0 118 | else: 119 | p = np.max(prec[rec >= t]) 120 | ap = ap + p / 11.0 121 | else: 122 | # correct AP calculation 123 | # first append sentinel values at the end 124 | mrec = np.concatenate(([0.0], rec, [1.0])) 125 | mpre = np.concatenate(([0.0], prec, [0.0])) 126 | 127 | # compute the precision envelope 128 | for i in range(mpre.size - 1, 0, -1): 129 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 130 | 131 | # to calculate area under PR curve, look for points 132 | # where X axis (recall) changes value 133 | i = np.where(mrec[1:] != mrec[:-1])[0] 134 | 135 | # and sum (\Delta recall) * prec 136 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 137 | return ap 138 | -------------------------------------------------------------------------------- /metrics/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5 3 | size 5455 4 | -------------------------------------------------------------------------------- /metrics/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf 3 | size 10057 4 | -------------------------------------------------------------------------------- /metrics/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c 3 | size 6735 4 | -------------------------------------------------------------------------------- /metrics/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0 3 | size 6009 4 | -------------------------------------------------------------------------------- /metrics/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76 3 | size 10811 4 | -------------------------------------------------------------------------------- /metrics/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 3 | size 7289 4 | -------------------------------------------------------------------------------- /metrics/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributional import compute_fid 2 | from .image import ( 3 | compute_mse, 4 | compute_psnr, 5 | compute_ssim, 6 | compute_nmi, 7 | compute_mse_repeated, 8 | compute_psnr_repeated, 9 | compute_ssim_repeated, 10 | compute_nmi_repeated, 11 | compute_image_distance_repeated, 12 | ) 13 | from .perceptual import ( 14 | load_perceptual_models, 15 | compute_lpips, 16 | compute_watson, 17 | compute_lpips_repeated, 18 | 
compute_watson_repeated, 19 | compute_perceptual_metric_repeated, 20 | ) 21 | from .aesthetics import ( 22 | load_aesthetics_and_artifacts_models, 23 | compute_aesthetics_and_artifacts_scores, 24 | ) 25 | from .clip import load_open_clip_model_preprocess_and_tokenizer, compute_clip_score 26 | from .prompt import ( 27 | load_perplexity_model_and_tokenizer, 28 | compute_prompt_perplexity, 29 | ) 30 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from transformers import CLIPModel, CLIPProcessor 4 | from .aesthetics_scorer import preprocess, load_model 5 | 6 | 7 | def load_aesthetics_and_artifacts_models(device=torch.device("cuda")): 8 | model = CLIPModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") 9 | vision_model = model.vision_model 10 | vision_model.to(device) 11 | del model 12 | clip_processor = CLIPProcessor.from_pretrained( 13 | "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 14 | ) 15 | rating_model = load_model("aesthetics_scorer_rating_openclip_vit_h_14").to(device) 16 | artifacts_model = load_model("aesthetics_scorer_artifacts_openclip_vit_h_14").to( 17 | device 18 | ) 19 | return vision_model, clip_processor, rating_model, artifacts_model 20 | 21 | 22 | def compute_aesthetics_and_artifacts_scores( 23 | images, models, device=torch.device("cuda") 24 | ): 25 | vision_model, clip_processor, rating_model, artifacts_model = models 26 | 27 | inputs = clip_processor(images=images, return_tensors="pt").to(device) 28 | with torch.no_grad(): 29 | vision_output = vision_model(**inputs) 30 | pooled_output = vision_output.pooler_output 31 | embedding = preprocess(pooled_output) 32 | with torch.no_grad(): 33 | rating = rating_model(embedding) 34 | artifact = artifacts_model(embedding) 35 | return ( 36 | rating.detach().cpu().numpy().flatten().tolist(), 37 | artifact.detach().cpu().numpy().flatten().tolist(), 38 | ) 39 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/kenjiqq/aesthetics-scorer#validation-split-of-diffusiondb-dataset 3 | """ 4 | from .model import preprocess, load_model 5 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import json 4 | import os 5 | import inspect 6 | 7 | 8 | class AestheticScorer(nn.Module): 9 | def __init__( 10 | self, 11 | input_size=0, 12 | use_activation=False, 13 | dropout=0.2, 14 | config=None, 15 | hidden_dim=1024, 16 | reduce_dims=False, 17 | output_activation=None, 18 | ): 19 | super().__init__() 20 | self.config = { 21 | "input_size": input_size, 22 | "use_activation": use_activation, 23 | "dropout": dropout, 24 | "hidden_dim": hidden_dim, 25 | "reduce_dims": reduce_dims, 26 | "output_activation": output_activation, 27 | } 28 | if config != None: 29 | self.config.update(config) 30 | 31 | layers = [ 32 | nn.Linear(self.config["input_size"], self.config["hidden_dim"]), 33 | nn.ReLU() if self.config["use_activation"] else None, 34 | nn.Dropout(self.config["dropout"]), 35 | nn.Linear( 36 | self.config["hidden_dim"], 37 | round(self.config["hidden_dim"] / (2 if 
reduce_dims else 1)), 38 | ), 39 | nn.ReLU() if self.config["use_activation"] else None, 40 | nn.Dropout(self.config["dropout"]), 41 | nn.Linear( 42 | round(self.config["hidden_dim"] / (2 if reduce_dims else 1)), 43 | round(self.config["hidden_dim"] / (4 if reduce_dims else 1)), 44 | ), 45 | nn.ReLU() if self.config["use_activation"] else None, 46 | nn.Dropout(self.config["dropout"]), 47 | nn.Linear( 48 | round(self.config["hidden_dim"] / (4 if reduce_dims else 1)), 49 | round(self.config["hidden_dim"] / (8 if reduce_dims else 1)), 50 | ), 51 | nn.ReLU() if self.config["use_activation"] else None, 52 | nn.Linear(round(self.config["hidden_dim"] / (8 if reduce_dims else 1)), 1), 53 | ] 54 | if self.config["output_activation"] == "sigmoid": 55 | layers.append(nn.Sigmoid()) 56 | layers = [x for x in layers if x is not None] 57 | self.layers = nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | if self.config["output_activation"] == "sigmoid": 61 | upper, lower = 10, 1 62 | scale = upper - lower 63 | return (self.layers(x) * scale) + lower 64 | else: 65 | return self.layers(x) 66 | 67 | def save(self, save_name): 68 | split_name = os.path.splitext(save_name) 69 | with open(f"{split_name[0]}.config", "w") as outfile: 70 | outfile.write(json.dumps(self.config, indent=4)) 71 | 72 | for i in range( 73 | 6 74 | ): # saving sometiles fails, so retry 5 times, might be windows issue 75 | try: 76 | torch.save(self.state_dict(), save_name) 77 | break 78 | except RuntimeError as e: 79 | # check if error contains string "File" 80 | if "cannot be opened" in str(e) and i < 5: 81 | print("Model save failed, retrying...") 82 | else: 83 | raise e 84 | 85 | 86 | def preprocess(embeddings): 87 | return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True) 88 | 89 | 90 | def load_model(weight_name, device="cuda" if torch.cuda.is_available() else "cpu"): 91 | weight_folder = os.path.abspath( 92 | os.path.join( 93 | inspect.getfile(load_model), 94 | "../weights", 95 | ) 96 | ) 97 | weight_path = os.path.join(weight_folder, f"{weight_name}.pth") 98 | config_path = os.path.join(weight_folder, f"{weight_name}.config") 99 | with open(config_path, "r") as config_file: 100 | config = json.load(config_file) 101 | model = AestheticScorer(config=config) 102 | model.load_state_dict(torch.load(weight_path, map_location=device)) 103 | model.eval() 104 | return model 105 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.config: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 1664, 3 | "use_activation": false, 4 | "dropout": 0.0, 5 | "hidden_dim": 1024, 6 | "reduce_dims": false, 7 | "output_activation": null 8 | } -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_bigg_14.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:39a5d014670226d52c408e0dfec840b7626d80a73d003a6a144caafd5e02d031 3 | size 19423219 4 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.config: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 1280, 3 | "use_activation": false, 4 | "dropout": 
0.0, 5 | "hidden_dim": 1024, 6 | "reduce_dims": false, 7 | "output_activation": null 8 | } -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_h_14.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dc48a8a2315cfdbc7bb8278be55f645e8a995e1a2fa234baec5eb41c4d33e070 3 | size 17850319 4 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.config: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 1024, 3 | "use_activation": false, 4 | "dropout": 0.0, 5 | "hidden_dim": 1024, 6 | "reduce_dims": false, 7 | "output_activation": null 8 | } -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_artifacts_openclip_vit_l_14.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c4a9481fdbce5ff02b252bcb25109b9f3b29841289fadf7e79e884d59f9357d5 3 | size 16801743 4 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.config: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 1664, 3 | "use_activation": false, 4 | "dropout": 0.0, 5 | "hidden_dim": 1024, 6 | "reduce_dims": false, 7 | "output_activation": null 8 | } -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_bigg_14.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:19b016304f54ae866e27f1eb498c0861f704958e7c37693adc5ce094e63904a8 3 | size 19423099 4 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.config: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 1280, 3 | "use_activation": false, 4 | "dropout": 0.0, 5 | "hidden_dim": 1024, 6 | "reduce_dims": false, 7 | "output_activation": null 8 | } -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_h_14.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:03603eee1864c2e5e97ef7079229609653db5b10594ca8b1de9e541d838cae9c 3 | size 17850199 4 | -------------------------------------------------------------------------------- /metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.config: -------------------------------------------------------------------------------- 1 | { 2 | "input_size": 1024, 3 | "use_activation": false, 4 | "dropout": 0.0, 5 | "hidden_dim": 1024, 6 | "reduce_dims": false, 7 | "output_activation": null 8 | } -------------------------------------------------------------------------------- 
/metrics/metrics/aesthetics_scorer/weights/aesthetics_scorer_rating_openclip_vit_l_14.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:eb7fe561369ab6c7dad34b9316a56d2c6070582f0323656148e1107a242cd666 3 | size 16801623 4 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/GaParmar/clean-fid/tree/main 3 | """ 4 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/clip_features.py: -------------------------------------------------------------------------------- 1 | # pip install git+https://github.com/openai/CLIP.git 2 | import pdb 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | import clip 8 | from .fid import compute_fid 9 | 10 | 11 | def img_preprocess_clip(img_np): 12 | x = Image.fromarray(img_np.astype(np.uint8)).convert("RGB") 13 | T = transforms.Compose( 14 | [ 15 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), 16 | transforms.CenterCrop(224), 17 | ] 18 | ) 19 | return np.asarray(T(x)).clip(0, 255).astype(np.uint8) 20 | 21 | 22 | class CLIP_fx: 23 | def __init__(self, name="ViT-B/32", device="cuda"): 24 | self.model, _ = clip.load(name, device=device) 25 | self.model.eval() 26 | self.name = "clip_" + name.lower().replace("-", "_").replace("/", "_") 27 | 28 | def __call__(self, img_t): 29 | img_x = img_t / 255.0 30 | T_norm = transforms.Normalize( 31 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 32 | ) 33 | img_x = T_norm(img_x) 34 | assert torch.is_tensor(img_x) 35 | if len(img_x.shape) == 3: 36 | img_x = img_x.unsqueeze(0) 37 | B, C, H, W = img_x.shape 38 | with torch.no_grad(): 39 | z = self.model.encode_image(img_x) 40 | return z 41 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/downloads_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import requests 4 | import shutil 5 | 6 | 7 | inception_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" 8 | 9 | 10 | """ 11 | Download the pretrined inception weights if it does not exists 12 | ARGS: 13 | fpath - output folder path 14 | """ 15 | 16 | 17 | def check_download_inception(fpath="./"): 18 | inception_path = os.path.join(fpath, "inception-2015-12-05.pt") 19 | if not os.path.exists(inception_path): 20 | # download the file 21 | with urllib.request.urlopen(inception_url) as response, open( 22 | inception_path, "wb" 23 | ) as f: 24 | shutil.copyfileobj(response, f) 25 | return inception_path 26 | 27 | 28 | """ 29 | Download any url if it does not exist 30 | ARGS: 31 | local_folder - output folder path 32 | url - the weburl to download 33 | """ 34 | 35 | 36 | def check_download_url(local_folder, url): 37 | name = os.path.basename(url) 38 | local_path = os.path.join(local_folder, name) 39 | if not os.path.exists(local_path): 40 | os.makedirs(local_folder, exist_ok=True) 41 | print(f"downloading statistics to {local_path}") 42 | with urllib.request.urlopen(url) as response, open(local_path, "wb") as f: 43 | shutil.copyfileobj(response, f) 44 | return local_path 45 | 46 | 47 | 
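# --- Usage sketch (not part of the original clean-fid file; added for illustration) ---
# check_download_url caches a remote file under local_folder and returns the
# local path; once the file exists, later calls are no-ops. This helper reuses
# the module-level inception_url defined above; the cache folder name is an
# illustrative assumption.
def _demo_cache_inception_weights(local_folder="/tmp/clean_fid_weights"):
    return check_download_url(local_folder=local_folder, url=inception_url)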
""" 48 | Download a file from google drive 49 | ARGS: 50 | file_id - id of the google drive file 51 | out_path - output folder path 52 | """ 53 | 54 | 55 | def download_google_drive(file_id, out_path): 56 | def get_confirm_token(response): 57 | for key, value in response.cookies.items(): 58 | if key.startswith("download_warning"): 59 | return value 60 | return None 61 | 62 | URL = "https://drive.google.com/uc?export=download" 63 | session = requests.Session() 64 | response = session.get(URL, params={"id": file_id}, stream=True) 65 | token = get_confirm_token(response) 66 | 67 | if token: 68 | params = {"id": file_id, "confirm": token} 69 | response = session.get(URL, params=params, stream=True) 70 | 71 | CHUNK_SIZE = 32768 72 | with open(out_path, "wb") as f: 73 | for chunk in response.iter_content(CHUNK_SIZE): 74 | if chunk: 75 | f.write(chunk) 76 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | helpers for extracting features from image 3 | """ 4 | import os 5 | import platform 6 | import numpy as np 7 | import torch 8 | from torch.hub import get_dir 9 | from .downloads_helper import check_download_url 10 | from .inception_pytorch import InceptionV3 11 | from .inception_torchscript import InceptionV3W 12 | 13 | 14 | """ 15 | returns a functions that takes an image in range [0,255] 16 | and outputs a feature embedding vector 17 | """ 18 | 19 | 20 | def feature_extractor( 21 | name="torchscript_inception", 22 | device=torch.device("cuda"), 23 | resize_inside=False, 24 | use_dataparallel=True, 25 | ): 26 | if name == "torchscript_inception": 27 | path = "./" if platform.system() == "Windows" else "/tmp" 28 | model = InceptionV3W(path, download=True, resize_inside=resize_inside).to( 29 | device 30 | ) 31 | model.eval() 32 | if use_dataparallel: 33 | model = torch.nn.DataParallel(model) 34 | 35 | def model_fn(x): 36 | return model(x) 37 | 38 | elif name == "pytorch_inception": 39 | model = InceptionV3(output_blocks=[3], resize_input=False).to(device) 40 | model.eval() 41 | if use_dataparallel: 42 | model = torch.nn.DataParallel(model) 43 | 44 | def model_fn(x): 45 | return model(x / 255)[0].squeeze(-1).squeeze(-1) 46 | 47 | else: 48 | raise ValueError(f"{name} feature extractor not implemented") 49 | return model_fn 50 | 51 | 52 | """ 53 | Build a feature extractor for each of the modes 54 | """ 55 | 56 | 57 | def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True): 58 | if mode == "legacy_pytorch": 59 | feat_model = feature_extractor( 60 | name="pytorch_inception", 61 | resize_inside=False, 62 | device=device, 63 | use_dataparallel=use_dataparallel, 64 | ) 65 | elif mode == "legacy_tensorflow": 66 | feat_model = feature_extractor( 67 | name="torchscript_inception", 68 | resize_inside=True, 69 | device=device, 70 | use_dataparallel=use_dataparallel, 71 | ) 72 | elif mode == "clean": 73 | feat_model = feature_extractor( 74 | name="torchscript_inception", 75 | resize_inside=False, 76 | device=device, 77 | use_dataparallel=use_dataparallel, 78 | ) 79 | return feat_model 80 | 81 | 82 | """ 83 | Load precomputed reference statistics for commonly used datasets 84 | """ 85 | 86 | 87 | def get_reference_statistics( 88 | name, 89 | res, 90 | mode="clean", 91 | model_name="inception_v3", 92 | seed=0, 93 | split="test", 94 | metric="FID", 95 | ): 96 | base_url = "https://www.cs.cmu.edu/~clean-fid/stats/" 97 | 
if split == "custom": 98 | res = "na" 99 | if model_name == "inception_v3": 100 | model_modifier = "" 101 | else: 102 | model_modifier = "_" + model_name 103 | if metric == "FID": 104 | rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz") 105 | url = f"{base_url}/{rel_path}" 106 | stats_folder = os.path.join(get_dir(), "fid_stats") 107 | fpath = check_download_url(local_folder=stats_folder, url=url) 108 | stats = np.load(fpath) 109 | mu, sigma = stats["mu"], stats["sigma"] 110 | return mu, sigma 111 | elif metric == "KID": 112 | rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz") 113 | url = f"{base_url}/{rel_path}" 114 | stats_folder = os.path.join(get_dir(), "fid_stats") 115 | fpath = check_download_url(local_folder=stats_folder, url=url) 116 | stats = np.load(fpath) 117 | return stats["feats"] 118 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/inception_torchscript.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import contextlib 5 | from .downloads_helper import * 6 | 7 | 8 | @contextlib.contextmanager 9 | def disable_gpu_fuser_on_pt19(): 10 | # On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. See 11 | # https://github.com/GaParmar/clean-fid/issues/5 12 | # https://github.com/pytorch/pytorch/issues/64062 13 | if torch.__version__.startswith("1.9."): 14 | old_val = torch._C._jit_can_fuse_on_gpu() 15 | torch._C._jit_override_can_fuse_on_gpu(False) 16 | yield 17 | if torch.__version__.startswith("1.9."): 18 | torch._C._jit_override_can_fuse_on_gpu(old_val) 19 | 20 | 21 | class InceptionV3W(nn.Module): 22 | """ 23 | Wrapper around Inception V3 torchscript model provided here 24 | https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt 25 | 26 | path: locally saved inception weights 27 | """ 28 | 29 | def __init__(self, path, download=True, resize_inside=False): 30 | super(InceptionV3W, self).__init__() 31 | # download the network if it is not present at the given directory 32 | # use the current directory by default 33 | if download: 34 | check_download_inception(fpath=path) 35 | path = os.path.join(path, "inception-2015-12-05.pt") 36 | self.base = torch.jit.load(path).eval() 37 | self.layers = self.base.layers 38 | self.resize_inside = resize_inside 39 | 40 | """ 41 | Get the inception features without resizing 42 | x: Image with values in range [0,255] 43 | """ 44 | 45 | def forward(self, x): 46 | with disable_gpu_fuser_on_pt19(): 47 | bs = x.shape[0] 48 | if self.resize_inside: 49 | features = self.base(x, return_features=True).view((bs, 2048)) 50 | else: 51 | # make sure it is resized already 52 | assert (x.shape[2] == 299) and (x.shape[3] == 299) 53 | # apply normalization 54 | x1 = x - 128 55 | x2 = x1 / 128 56 | features = self.layers.forward( 57 | x2, 58 | ).view((bs, 2048)) 59 | return features 60 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/leaderboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import shutil 4 | import urllib.request 5 | 6 | 7 | def get_score( 8 | model_name=None, 9 | dataset_name=None, 10 | dataset_res=None, 11 | dataset_split=None, 12 | task_name=None, 13 | ): 14 | # download the csv file from server 15 | url = "https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv" 16 | 
local_path = "/tmp/leaderboard.csv" 17 | with urllib.request.urlopen(url) as response, open(local_path, "wb") as f: 18 | shutil.copyfileobj(response, f) 19 | 20 | d_field2idx = {} 21 | l_matches = [] 22 | with open(local_path, "r") as f: 23 | csvreader = csv.reader(f) 24 | l_fields = next(csvreader) 25 | for idx, val in enumerate(l_fields): 26 | d_field2idx[val.strip()] = idx 27 | # iterate through all rows 28 | for row in csvreader: 29 | # skip empty rows 30 | if len(row) == 0: 31 | continue 32 | # skip if the filter doesn't match 33 | if model_name is not None and ( 34 | row[d_field2idx["model_name"]].strip() != model_name 35 | ): 36 | continue 37 | if dataset_name is not None and ( 38 | row[d_field2idx["dataset_name"]].strip() != dataset_name 39 | ): 40 | continue 41 | if dataset_res is not None and ( 42 | row[d_field2idx["dataset_res"]].strip() != dataset_res 43 | ): 44 | continue 45 | if dataset_split is not None and ( 46 | row[d_field2idx["dataset_split"]].strip() != dataset_split 47 | ): 48 | continue 49 | if task_name is not None and ( 50 | row[d_field2idx["task_name"]].strip() != task_name 51 | ): 52 | continue 53 | curr = {} 54 | for f in l_fields: 55 | curr[f.strip()] = row[d_field2idx[f.strip()]].strip() 56 | l_matches.append(curr) 57 | os.remove(local_path) 58 | return l_matches 59 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/resize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for resizing with multiple CPU cores 3 | """ 4 | import os 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | import torch.nn.functional as F 9 | 10 | 11 | def build_resizer(mode): 12 | if mode == "clean": 13 | return make_resizer("PIL", False, "bicubic", (299, 299)) 14 | # if using legacy tensorflow, do not manually resize outside the network 15 | elif mode == "legacy_tensorflow": 16 | return lambda x: x 17 | elif mode == "legacy_pytorch": 18 | return make_resizer("PyTorch", False, "bilinear", (299, 299)) 19 | else: 20 | raise ValueError(f"Invalid mode {mode} specified") 21 | 22 | 23 | """ 24 | Construct a function that resizes a numpy image based on the 25 | flags passed in. 
26 | """ 27 | 28 | 29 | def make_resizer(library, quantize_after, filter, output_size): 30 | if library == "PIL" and quantize_after: 31 | name_to_filter = { 32 | "bicubic": Image.BICUBIC, 33 | "bilinear": Image.BILINEAR, 34 | "nearest": Image.NEAREST, 35 | "lanczos": Image.LANCZOS, 36 | "box": Image.BOX, 37 | } 38 | 39 | def func(x): 40 | x = Image.fromarray(x) 41 | x = x.resize(output_size, resample=name_to_filter[filter]) 42 | x = np.asarray(x).clip(0, 255).astype(np.uint8) 43 | return x 44 | 45 | elif library == "PIL" and not quantize_after: 46 | name_to_filter = { 47 | "bicubic": Image.BICUBIC, 48 | "bilinear": Image.BILINEAR, 49 | "nearest": Image.NEAREST, 50 | "lanczos": Image.LANCZOS, 51 | "box": Image.BOX, 52 | } 53 | s1, s2 = output_size 54 | 55 | def resize_single_channel(x_np): 56 | img = Image.fromarray(x_np.astype(np.float32), mode="F") 57 | img = img.resize(output_size, resample=name_to_filter[filter]) 58 | return np.asarray(img).clip(0, 255).reshape(s2, s1, 1) 59 | 60 | def func(x): 61 | x = [resize_single_channel(x[:, :, idx]) for idx in range(3)] 62 | x = np.concatenate(x, axis=2).astype(np.float32) 63 | return x 64 | 65 | elif library == "PyTorch": 66 | import warnings 67 | 68 | # ignore the numpy warnings 69 | warnings.filterwarnings("ignore") 70 | 71 | def func(x): 72 | x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...] 73 | x = F.interpolate(x, size=output_size, mode=filter, align_corners=False) 74 | x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255) 75 | if quantize_after: 76 | x = x.astype(np.uint8) 77 | return x 78 | 79 | else: 80 | raise NotImplementedError("library [%s] is not include" % library) 81 | return func 82 | 83 | 84 | class FolderResizer(torch.utils.data.Dataset): 85 | def __init__(self, files, outpath, fn_resize, output_ext=".png"): 86 | self.files = files 87 | self.outpath = outpath 88 | self.output_ext = output_ext 89 | self.fn_resize = fn_resize 90 | 91 | def __len__(self): 92 | return len(self.files) 93 | 94 | def __getitem__(self, i): 95 | path = str(self.files[i]) 96 | img_np = np.asarray(Image.open(path)) 97 | img_resize_np = self.fn_resize(img_np) 98 | # swap the output extension 99 | basename = os.path.basename(path).split(".")[0] + self.output_ext 100 | outname = os.path.join(self.outpath, basename) 101 | if self.output_ext == ".npy": 102 | np.save(outname, img_resize_np) 103 | elif self.output_ext == ".png": 104 | img_resized_pil = Image.fromarray(img_resize_np) 105 | img_resized_pil.save(outname) 106 | else: 107 | raise ValueError("invalid output extension") 108 | return 0 109 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from PIL import Image 5 | import zipfile 6 | from .resize import build_resizer 7 | 8 | 9 | class ResizeDataset(torch.utils.data.Dataset): 10 | """ 11 | A placeholder Dataset that enables parallelizing the resize operation 12 | using multiple CPU cores 13 | 14 | files: list of all files in the folder 15 | fn_resize: function that takes an np_array as input [0,255] 16 | """ 17 | 18 | def __init__(self, files, mode, size=(299, 299), fdir=None): 19 | self.files = files 20 | self.fdir = fdir 21 | self.transforms = torchvision.transforms.ToTensor() 22 | self.size = size 23 | self.fn_resize = build_resizer(mode) 24 | self.custom_image_tranform = lambda x: x 25 | self._zipfile = None 26 | 27 | def 
_get_zipfile(self): 28 | assert self.fdir is not None and ".zip" in self.fdir 29 | if self._zipfile is None: 30 | self._zipfile = zipfile.ZipFile(self.fdir) 31 | return self._zipfile 32 | 33 | def __len__(self): 34 | return len(self.files) 35 | 36 | def __getitem__(self, i): 37 | path = str(self.files[i]) 38 | if self.fdir is not None and ".zip" in self.fdir: 39 | with self._get_zipfile().open(path, "r") as f: 40 | img_np = np.array(Image.open(f).convert("RGB")) 41 | elif ".npy" in path: 42 | img_np = np.load(path) 43 | else: 44 | img_pil = Image.open(path).convert("RGB") 45 | img_np = np.array(img_pil) 46 | 47 | # apply a custom image transform before resizing the image to 299x299 48 | img_np = self.custom_image_tranform(img_np) 49 | # fn_resize expects a np array and returns a np array 50 | img_resized = self.fn_resize(img_np) 51 | 52 | # ToTensor() converts to [0,1] only if input in uint8 53 | if img_resized.dtype == "uint8": 54 | img_t = self.transforms(np.array(img_resized)) * 255 55 | elif img_resized.dtype == "float32": 56 | img_t = self.transforms(img_resized) 57 | 58 | return img_t 59 | 60 | 61 | EXTENSIONS = { 62 | "bmp", 63 | "jpg", 64 | "jpeg", 65 | "pgm", 66 | "png", 67 | "ppm", 68 | "tif", 69 | "tiff", 70 | "webp", 71 | "npy", 72 | "JPEG", 73 | "JPG", 74 | "PNG", 75 | } 76 | -------------------------------------------------------------------------------- /metrics/metrics/clean_fid/wrappers.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from .features import build_feature_extractor, get_reference_statistics 5 | from .fid import get_batch_features, fid_from_feats 6 | from .resize import build_resizer 7 | 8 | 9 | """ 10 | A helper class that allowing adding the images one batch at a time. 11 | """ 12 | 13 | 14 | class CleanFID: 15 | def __init__(self, mode="clean", model_name="inception_v3", device="cuda"): 16 | self.real_features = [] 17 | self.gen_features = [] 18 | self.mode = mode 19 | self.device = device 20 | if model_name == "inception_v3": 21 | self.feat_model = build_feature_extractor(mode, device) 22 | self.fn_resize = build_resizer(mode) 23 | elif model_name == "clip_vit_b_32": 24 | from .clip_features import CLIP_fx, img_preprocess_clip 25 | 26 | clip_fx = CLIP_fx("ViT-B/32") 27 | self.feat_model = clip_fx 28 | self.fn_resize = img_preprocess_clip 29 | 30 | """ 31 | Funtion that takes an image (PIL.Image or np.array or torch.tensor) 32 | and returns the corresponding feature embedding vector. 
33 | The image x is expected to be in range [0, 255] 34 | """ 35 | 36 | def compute_features(self, x): 37 | # if x is a PIL Image 38 | if isinstance(x, Image.Image): 39 | x_np = np.array(x) 40 | x_np_resized = self.fn_resize(x_np) 41 | x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0) 42 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 43 | elif isinstance(x, np.ndarray): 44 | x_np_resized = self.fn_resize(x) 45 | x_t = ( 46 | torch.tensor(x_np_resized.transpose((2, 0, 1))) 47 | .unsqueeze(0) 48 | .to(self.device) 49 | ) 50 | # normalization happens inside the self.feat_model, expected image range here is [0,255] 51 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 52 | elif isinstance(x, torch.Tensor): 53 | # pdb.set_trace() 54 | # add the batch dimension if x is passed in as C,H,W 55 | if len(x.shape) == 3: 56 | x = x.unsqueeze(0) 57 | b, c, h, w = x.shape 58 | # convert back to np array and resize 59 | l_x_np_resized = [] 60 | for _ in range(b): 61 | x_np = x[_].cpu().numpy().transpose((1, 2, 0)) 62 | l_x_np_resized.append(self.fn_resize(x_np)[None,]) 63 | x_np_resized = np.concatenate(l_x_np_resized) 64 | x_t = torch.tensor(x_np_resized.transpose((0, 3, 1, 2))).to(self.device) 65 | # normalization happens inside the self.feat_model, expected image range here is [0,255] 66 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 67 | else: 68 | raise ValueError("image type could not be inferred") 69 | return x_feat 70 | 71 | """ 72 | Extract the faetures from x and add to the list of reference real images 73 | """ 74 | 75 | def add_real_images(self, x): 76 | x_feat = self.compute_features(x) 77 | self.real_features.append(x_feat) 78 | 79 | """ 80 | Extract the faetures from x and add to the list of generated images 81 | """ 82 | 83 | def add_gen_images(self, x): 84 | x_feat = self.compute_features(x) 85 | self.gen_features.append(x_feat) 86 | 87 | """ 88 | Compute FID between the real and generated images added so far 89 | """ 90 | 91 | def calculate_fid(self, verbose=True): 92 | feats1 = np.concatenate(self.real_features) 93 | feats2 = np.concatenate(self.gen_features) 94 | if verbose: 95 | print(f"# real images = {feats1.shape[0]}") 96 | print(f"# generated images = {feats2.shape[0]}") 97 | return fid_from_feats(feats1, feats2) 98 | 99 | """ 100 | Remove the real image features added so far 101 | """ 102 | 103 | def reset_real_features(self): 104 | self.real_features = [] 105 | 106 | """ 107 | Remove the generated image features added so far 108 | """ 109 | 110 | def reset_gen_features(self): 111 | self.gen_features = [] 112 | -------------------------------------------------------------------------------- /metrics/metrics/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import open_clip 4 | 5 | 6 | def load_open_clip_model_preprocess_and_tokenizer(device=torch.device("cuda")): 7 | clip_model, _, clip_preprocess = open_clip.create_model_and_transforms( 8 | "ViT-g-14", pretrained="laion2b_s12b_b42k", device=device 9 | ) 10 | clip_tokenizer = open_clip.get_tokenizer("ViT-g-14") 11 | return clip_model, clip_preprocess, clip_tokenizer 12 | 13 | 14 | def compute_clip_score( 15 | images, 16 | prompts, 17 | models, 18 | device=torch.device("cuda"), 19 | ): 20 | clip_model, clip_preprocess, clip_tokenizer = models 21 | with torch.no_grad(): 22 | tensors = [clip_preprocess(image) for image in images] 23 | image_processed_tensor = torch.stack(tensors, 
0).to(device) 24 | image_features = clip_model.encode_image(image_processed_tensor) 25 | 26 | encoding = clip_tokenizer(prompts).to(device) 27 | text_features = clip_model.encode_text(encoding) 28 | 29 | image_features /= image_features.norm(dim=-1, keepdim=True) 30 | text_features /= text_features.norm(dim=-1, keepdim=True) 31 | 32 | return (image_features @ text_features.T).mean(-1).cpu().numpy().tolist() 33 | -------------------------------------------------------------------------------- /metrics/metrics/distributional.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import torch 4 | from PIL import Image 5 | from tqdm.auto import tqdm 6 | from concurrent.futures import ProcessPoolExecutor 7 | from functools import partial 8 | from PIL import Image 9 | from .clean_fid import fid 10 | 11 | 12 | def save_single_image_to_temp(i, image, temp_dir): 13 | save_path = os.path.join(temp_dir, f"{i}.png") 14 | image.save(save_path, "PNG") 15 | 16 | 17 | def save_images_to_temp(images, num_workers, verbose=False): 18 | assert isinstance(images, list) and isinstance(images[0], Image.Image) 19 | temp_dir = tempfile.mkdtemp() 20 | 21 | # Using ProcessPoolExecutor to save images in parallel 22 | func = partial(save_single_image_to_temp, temp_dir=temp_dir) 23 | with ProcessPoolExecutor(max_workers=num_workers) as executor: 24 | tasks = executor.map(func, range(len(images)), images) 25 | list(tasks) if not verbose else list( 26 | tqdm( 27 | tasks, 28 | total=len(images), 29 | desc="Saving images ", 30 | ) 31 | ) 32 | return temp_dir 33 | 34 | 35 | # Compute FID between two sets of images 36 | def compute_fid( 37 | images1, 38 | images2, 39 | mode="legacy", 40 | device=None, 41 | batch_size=64, 42 | num_workers=None, 43 | verbose=False, 44 | ): 45 | # Support four types of FID scores 46 | assert mode in ["legacy", "clean", "clip"] 47 | if mode == "legacy": 48 | mode = "legacy_pytorch" 49 | model_name = "inception_v3" 50 | elif mode == "clean": 51 | mode = "clean" 52 | model_name = "inception_v3" 53 | elif mode == "clip": 54 | mode = "clean" 55 | model_name = "clip_vit_b_32" 56 | else: 57 | assert False 58 | 59 | # Set up device and num_workers 60 | if device is None: 61 | device = ( 62 | torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 63 | ) 64 | if num_workers is not None: 65 | assert 1 <= num_workers <= os.cpu_count() 66 | else: 67 | num_workers = max(torch.cuda.device_count() * 4, 8) 68 | 69 | # Check images, can be paths or lists of PIL images 70 | if not isinstance(images1, list): 71 | assert isinstance(images1, str) and os.path.exists(images1) 72 | assert isinstance(images2, str) and os.path.exists(images2) 73 | path1 = images1 74 | path2 = images2 75 | else: 76 | assert isinstance(images1, list) and isinstance(images1[0], Image.Image) 77 | assert isinstance(images2, list) and isinstance(images2[0], Image.Image) 78 | # Save images to temp dir if needed 79 | path1 = save_images_to_temp(images1, num_workers=num_workers, verbose=verbose) 80 | path2 = save_images_to_temp(images2, num_workers=num_workers, verbose=verbose) 81 | 82 | # Attempt to cache statistics for path1 83 | if not fid.test_stats_exists(name=str(os.path.abspath(path1)).replace("/", "_"), mode=mode, model_name=model_name): 84 | fid.make_custom_stats( 85 | name=str(os.path.abspath(path1)).replace("/", "_"), 86 | fdir=path1, 87 | mode=mode, 88 | model_name=model_name, 89 | device=device, 90 | num_workers=num_workers, 91 | verbose=verbose, 
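            # The absolute path of the reference folder, with "/" replaced by "_", acts as a cache
            # key: when clean-fid has no statistics stored under that name, make_custom_stats
            # precomputes and saves them, and the fid.compute_fid call below retrieves them via
            # dataset_name with dataset_split="custom", treating path2 as the generated set. Later
            # runs against the same reference folder can then reuse the cached statistics.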
92 | ) 93 | fid_score = fid.compute_fid( 94 | path2, 95 | dataset_name=str(os.path.abspath(path1)).replace("/", "_"), 96 | dataset_split="custom", 97 | mode=mode, 98 | model_name=model_name, 99 | device=device, 100 | batch_size=batch_size, 101 | num_workers=num_workers, 102 | verbose=verbose, 103 | ) 104 | return fid_score 105 | -------------------------------------------------------------------------------- /metrics/metrics/image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | from skimage.metrics import ( 6 | mean_squared_error, 7 | peak_signal_noise_ratio, 8 | structural_similarity as structural_similarity_index_measure, 9 | normalized_mutual_information, 10 | ) 11 | from tqdm.auto import tqdm 12 | from concurrent.futures import ThreadPoolExecutor 13 | 14 | 15 | # Process images to numpy arrays 16 | def convert_image_pair_to_numpy(image1, image2): 17 | assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image) 18 | 19 | image1_np = np.array(image1) 20 | image2_np = np.array(image2) 21 | assert image1_np.shape == image2_np.shape 22 | 23 | return image1_np, image2_np 24 | 25 | 26 | # Compute MSE between two images 27 | def compute_mse(image1, image2): 28 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 29 | return float(mean_squared_error(image1_np, image2_np)) 30 | 31 | 32 | # Compute PSNR between two images 33 | def compute_psnr(image1, image2): 34 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 35 | return float(peak_signal_noise_ratio(image1_np, image2_np)) 36 | 37 | 38 | # Compute SSIM between two images 39 | def compute_ssim(image1, image2): 40 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 41 | return float( 42 | structural_similarity_index_measure(image1_np, image2_np, channel_axis=2) 43 | ) 44 | 45 | 46 | # Compute NMI between two images 47 | def compute_nmi(image1, image2): 48 | image1_np, image2_np = convert_image_pair_to_numpy(image1, image2) 49 | return float(normalized_mutual_information(image1_np, image2_np)) 50 | 51 | 52 | # Compute metrics 53 | def compute_metric_repeated( 54 | images1, images2, metric_func, num_workers=None, verbose=False 55 | ): 56 | # Accept list of PIL images 57 | assert isinstance(images1, list) and isinstance(images1[0], Image.Image) 58 | assert isinstance(images2, list) and isinstance(images2[0], Image.Image) 59 | assert len(images1) == len(images2) 60 | 61 | if num_workers is not None: 62 | assert 1 <= num_workers <= os.cpu_count() 63 | else: 64 | num_workers = max(torch.cuda.device_count() * 4, 8) 65 | 66 | metric_name = metric_func.__name__.split("_")[1].upper() 67 | 68 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 69 | tasks = executor.map(metric_func, images1, images2) 70 | values = ( 71 | list(tasks) 72 | if not verbose 73 | else list( 74 | tqdm( 75 | tasks, 76 | total=len(images1), 77 | desc=f"{metric_name} ", 78 | ) 79 | ) 80 | ) 81 | return values 82 | 83 | 84 | # Compute MSE between pairs of images 85 | def compute_mse_repeated(images1, images2, num_workers=None, verbose=False): 86 | return compute_metric_repeated(images1, images2, compute_mse, num_workers, verbose) 87 | 88 | 89 | # Compute PSNR between pairs of images 90 | def compute_psnr_repeated(images1, images2, num_workers=None, verbose=False): 91 | return compute_metric_repeated(images1, images2, compute_psnr, num_workers, verbose) 92 | 93 | 94 | # Compute SSIM 
between pairs of images 95 | def compute_ssim_repeated(images1, images2, num_workers=None, verbose=False): 96 | return compute_metric_repeated(images1, images2, compute_ssim, num_workers, verbose) 97 | 98 | 99 | # Compute NMI between pairs of images 100 | def compute_nmi_repeated(images1, images2, num_workers=None, verbose=False): 101 | return compute_metric_repeated(images1, images2, compute_nmi, num_workers, verbose) 102 | 103 | 104 | def compute_image_distance_repeated( 105 | images1, images2, metric_name, num_workers=None, verbose=False 106 | ): 107 | metric_func = { 108 | "psnr": compute_psnr, 109 | "ssim": compute_ssim, 110 | "nmi": compute_nmi, 111 | }[metric_name] 112 | return compute_metric_repeated(images1, images2, metric_func, num_workers, verbose) 113 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/richzhang/PerceptualSimilarity 3 | """ 4 | from .lpips import LPIPS 5 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def normalize_tensor(in_feat, eps=1e-10): 9 | norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True)) 10 | return in_feat / (norm_factor + eps) 11 | 12 | 13 | def l2(p0, p1, range=255.0): 14 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 15 | 16 | 17 | def psnr(p0, p1, peak=255.0): 18 | return 10 * np.log10(peak**2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 19 | 20 | 21 | def dssim(p0, p1, range=255.0): 22 | from skimage.measure import compare_ssim 23 | 24 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 25 | 26 | 27 | def tensor2np(tensor_obj): 28 | # change dimension of a tensor object into a numpy array 29 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 30 | 31 | 32 | def np2tensor(np_obj): 33 | # change dimenion of np array into tensor array 34 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 35 | 36 | 37 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 38 | # image tensor to lab tensor 39 | from skimage import color 40 | 41 | img = tensor2im(image_tensor) 42 | img_lab = color.rgb2lab(img) 43 | if mc_only: 44 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 45 | if to_norm and not mc_only: 46 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 47 | img_lab = img_lab / 100.0 48 | 49 | return np2tensor(img_lab) 50 | 51 | 52 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 53 | from skimage import color 54 | import warnings 55 | 56 | warnings.filterwarnings("ignore") 57 | 58 | lab = tensor2np(lab_tensor) * 100.0 59 | lab[:, :, 0] = lab[:, :, 0] + 50 60 | 61 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 62 | if return_inbnd: 63 | # convert back to lab, see if we match 64 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 65 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 66 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 67 | return (im2tensor(rgb_back), mask) 68 | else: 69 | return im2tensor(rgb_back) 70 | 71 | 72 | def load_image(path): 73 | if ( 74 | path[-3:] == "bmp" 75 | or path[-3:] == "jpg" 76 | or path[-3:] == "png" 
77 | or path[-4:] == "jpeg" 78 | ): 79 | import cv2 80 | 81 | return cv2.imread(path)[:, :, ::-1] 82 | else: 83 | import matplotlib.pyplot as plt 84 | 85 | img = (255 * plt.imread(path)[:, :, :3]).astype("uint8") 86 | 87 | return img 88 | 89 | 90 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 91 | image_numpy = image_tensor[0].cpu().float().numpy() 92 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 93 | return image_numpy.astype(imtype) 94 | 95 | 96 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 97 | return torch.Tensor( 98 | (image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)) 99 | ) 100 | 101 | 102 | def tensor2vec(vector_tensor): 103 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 104 | 105 | 106 | def voc_ap(rec, prec, use_07_metric=False): 107 | """ap = voc_ap(rec, prec, [use_07_metric]) 108 | Compute VOC AP given precision and recall. 109 | If use_07_metric is true, uses the 110 | VOC 07 11 point method (default:False). 111 | """ 112 | if use_07_metric: 113 | # 11 point metric 114 | ap = 0.0 115 | for t in np.arange(0.0, 1.1, 0.1): 116 | if np.sum(rec >= t) == 0: 117 | p = 0 118 | else: 119 | p = np.max(prec[rec >= t]) 120 | ap = ap + p / 11.0 121 | else: 122 | # correct AP calculation 123 | # first append sentinel values at the end 124 | mrec = np.concatenate(([0.0], rec, [1.0])) 125 | mpre = np.concatenate(([0.0], prec, [0.0])) 126 | 127 | # compute the precision envelope 128 | for i in range(mpre.size - 1, 0, -1): 129 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 130 | 131 | # to calculate area under PR curve, look for points 132 | # where X axis (recall) changes value 133 | i = np.where(mrec[1:] != mrec[:-1])[0] 134 | 135 | # and sum (\Delta recall) * prec 136 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 137 | return ap 138 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:18720f55913d0af89042f13faa7e536a6ce1444a0914e6db9461355ece1e8cd5 3 | size 5455 4 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c27abd3a0145541baa50990817df58d3759c3f8154949f42af3b59b4e042d0bf 3 | size 10057 4 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b9e4236260c3dd988fc79d2a48d645d885afcbb21f9fd595e6744cf7419b582c 3 | size 6735 4 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:df73285e35b22355a2df87cdb6b70b343713b667eddbda73e1977e0c860835c0 3 | size 6009 4 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | 
oid sha256:4a5350f23600cb79923ce65bb07cbf57dca461329894153e05a1346bd531cf76 3 | size 10811 4 | -------------------------------------------------------------------------------- /metrics/metrics/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 3 | size 7289 4 | -------------------------------------------------------------------------------- /metrics/metrics/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm.auto import tqdm 5 | from utils import to_tensor 6 | from .lpips import LPIPS 7 | from .watson import LossProvider 8 | 9 | 10 | def load_perceptual_models(metric_name, mode, device=torch.device("cuda")): 11 | assert metric_name in ["lpips", "watson"] 12 | if metric_name == "lpips": 13 | assert mode in ["vgg", "alex"] 14 | perceptual_model = LPIPS(net=mode).to(device) 15 | elif metric_name == "watson": 16 | assert mode in ["vgg", "dft", "fft"] 17 | perceptual_model = ( 18 | LossProvider() 19 | .get_loss_function( 20 | "Watson-" + mode, colorspace="RGB", pretrained=True, reduction="none" 21 | ) 22 | .to(device) 23 | ) 24 | else: 25 | assert False 26 | return perceptual_model 27 | 28 | 29 | # Compute metric between two images 30 | def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")): 31 | assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image) 32 | image1_tensor = to_tensor([image1]).to(device) 33 | image2_tensor = to_tensor([image2]).to(device) 34 | return perceptual_model(image1_tensor, image2_tensor).cpu().item() 35 | 36 | 37 | # Compute LPIPS distance between two images 38 | def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")): 39 | perceptual_model = load_perceptual_models("lpips", mode, device) 40 | return compute_metric(image1, image2, perceptual_model, device) 41 | 42 | 43 | # Compute Watson distance between two images 44 | def compute_watson(image1, image2, mode="dft", device=torch.device("cuda")): 45 | perceptual_model = load_perceptual_models("watson", mode, device) 46 | return compute_metric(image1, image2, perceptual_model, device) 47 | 48 | 49 | # Compute metrics between pairs of images 50 | def compute_perceptual_metric_repeated( 51 | images1, 52 | images2, 53 | metric_name, 54 | mode, 55 | model, 56 | device, 57 | ): 58 | # Accept list of PIL images 59 | assert isinstance(images1, list) and isinstance(images1[0], Image.Image) 60 | assert isinstance(images2, list) and isinstance(images2[0], Image.Image) 61 | assert len(images1) == len(images2) 62 | if model is None: 63 | model = load_perceptual_models(metric_name, mode).to(device) 64 | return ( 65 | model(to_tensor(images1).to(device), to_tensor(images2).to(device)) 66 | .detach() 67 | .cpu() 68 | .numpy() 69 | .flatten() 70 | .tolist() 71 | ) 72 | 73 | 74 | # Compute LPIPS distance between pairs of images 75 | def compute_lpips_repeated( 76 | images1, 77 | images2, 78 | mode="alex", 79 | model=None, 80 | device=torch.device("cuda"), 81 | ): 82 | return compute_perceptual_metric_repeated( 83 | images1, images2, "lpips", mode, model, device 84 | ) 85 | 86 | 87 | # Compute Watson distance between pairs of images 88 | def compute_watson_repeated( 89 | images1, 90 | images2, 91 | mode="dft", 92 | model=None, 93 | device=torch.device("cuda"), 94 | ): 95 | 
return compute_perceptual_metric_repeated( 96 | images1, images2, "watson", mode, model, device 97 | ) 98 | -------------------------------------------------------------------------------- /metrics/metrics/prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast 3 | 4 | 5 | # Load GPT-2 large model and tokenizer 6 | def load_perplexity_model_and_tokenizer(): 7 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 8 | ppl_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device) 9 | ppl_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large") 10 | return ppl_model, ppl_tokenizer 11 | 12 | 13 | # Compute perplexity for a single prompt 14 | def compute_prompt_perplexity(prompt, models, stride=512): 15 | assert isinstance(prompt, str) 16 | assert isinstance(models, tuple) and len(models) == 2 17 | ppl_model, ppl_tokenizer = models 18 | encodings = ppl_tokenizer(prompt, return_tensors="pt") 19 | max_length = ppl_model.config.n_positions 20 | seq_len = encodings.input_ids.size(1) 21 | nlls = [] 22 | prev_end_loc = 0 23 | for begin_loc in range(0, seq_len, stride): 24 | end_loc = min(begin_loc + max_length, seq_len) 25 | trg_len = end_loc - prev_end_loc # may be different from stride on last loop 26 | input_ids = encodings.input_ids[:, begin_loc:end_loc].to( 27 | next(ppl_model.parameters()).device 28 | ) 29 | target_ids = input_ids.clone() 30 | target_ids[:, :-trg_len] = -100 31 | with torch.no_grad(): 32 | outputs = ppl_model(input_ids, labels=target_ids) 33 | neg_log_likelihood = outputs.loss 34 | nlls.append(neg_log_likelihood) 35 | prev_end_loc = end_loc 36 | if end_loc == seq_len: 37 | break 38 | ppl = torch.exp(torch.stack(nlls).mean()).item() 39 | return ppl 40 | -------------------------------------------------------------------------------- /metrics/metrics/watson/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/facebookresearch/stable_signature 3 | """ 4 | from .loss_provider import LossProvider 5 | -------------------------------------------------------------------------------- /metrics/metrics/watson/color_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class RGB2YCbCr(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | transf = torch.tensor( 10 | [[0.299, 0.587, 0.114], [-0.1687, -0.3313, 0.5], [0.5, -0.4187, -0.0813]] 11 | ).transpose(0, 1) 12 | self.transform = nn.Parameter(transf, requires_grad=False) 13 | bias = torch.tensor([0, 0.5, 0.5]) 14 | self.bias = nn.Parameter(bias, requires_grad=False) 15 | 16 | def forward(self, rgb): 17 | N, C, H, W = rgb.shape 18 | assert C == 3 19 | rgb = rgb.transpose(1, 3) 20 | cbcr = torch.matmul(rgb, self.transform) 21 | cbcr += self.bias 22 | return cbcr.transpose(1, 3) 23 | 24 | 25 | class ColorWrapper(nn.Module): 26 | """ 27 | Extension for single-channel loss to work on color images 28 | """ 29 | 30 | def __init__(self, lossclass, args, kwargs, trainable=False): 31 | """ 32 | Parameters: 33 | lossclass: class of the individual loss functions 34 | trainable: bool, if True parameters of the loss are trained. 
35 | args: tuple, arguments for instantiation of loss fun 36 | kwargs: dict, key word arguments for instantiation of loss fun 37 | """ 38 | super().__init__() 39 | 40 | # submodules 41 | self.add_module("to_YCbCr", RGB2YCbCr()) 42 | self.add_module("ly", lossclass(*args, **kwargs)) 43 | self.add_module("lcb", lossclass(*args, **kwargs)) 44 | self.add_module("lcr", lossclass(*args, **kwargs)) 45 | 46 | # weights 47 | self.w_tild = nn.Parameter(torch.zeros(3), requires_grad=trainable) 48 | 49 | @property 50 | def w(self): 51 | return F.softmax(self.w_tild, dim=0) 52 | 53 | def forward(self, input, target): 54 | # convert color space 55 | input = self.to_YCbCr(input) 56 | target = self.to_YCbCr(target) 57 | 58 | ly = self.ly(input[:, [0], :, :], target[:, [0], :, :]) 59 | lcb = self.lcb(input[:, [1], :, :], target[:, [1], :, :]) 60 | lcr = self.lcr(input[:, [2], :, :], target[:, [2], :, :]) 61 | 62 | w = self.w 63 | 64 | return ly * w[0] + lcb * w[1] + lcr * w[2] 65 | 66 | 67 | class GreyscaleWrapper(nn.Module): 68 | """ 69 | Maps 3 channel RGB or 1 channel greyscale input to 3 greyscale channels 70 | """ 71 | 72 | def __init__(self, lossclass, args, kwargs): 73 | """ 74 | Parameters: 75 | lossclass: class of the individual loss function 76 | args: tuple, arguments for instantiation of loss fun 77 | kwargs: dict, key word arguments for instantiation of loss fun 78 | """ 79 | super().__init__() 80 | 81 | # submodules 82 | self.add_module("loss", lossclass(*args, **kwargs)) 83 | 84 | def to_greyscale(self, tensor): 85 | return ( 86 | tensor[:, [0], :, :] * 0.3 87 | + tensor[:, [1], :, :] * 0.59 88 | + tensor[:, [2], :, :] * 0.11 89 | ) 90 | 91 | def forward(self, input, target): 92 | (N, C, X, Y) = input.size() 93 | 94 | if N == 3: 95 | # convert input to greyscale 96 | input = self.to_greyscale(input) 97 | target = self.to_greyscale(target) 98 | 99 | # input in now greyscale, expand to 3 channels 100 | input = input.expand(N, 3, X, Y) 101 | target = target.expand(N, 3, X, Y) 102 | 103 | return self.loss.forward(input, target) 104 | -------------------------------------------------------------------------------- /metrics/metrics/watson/dct2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Dct2d(nn.Module): 8 | """ 9 | Blockwhise 2D DCT 10 | """ 11 | 12 | def __init__(self, blocksize=8, interleaving=False): 13 | """ 14 | Parameters: 15 | blocksize: int, size of the Blocks for discrete cosine transform 16 | interleaving: bool, should the blocks interleave? 
17 | """ 18 | super().__init__() # call super constructor 19 | 20 | self.blocksize = blocksize 21 | self.interleaving = interleaving 22 | 23 | if interleaving: 24 | self.stride = self.blocksize // 2 25 | else: 26 | self.stride = self.blocksize 27 | 28 | # precompute DCT weight matrix 29 | A = np.zeros((blocksize, blocksize)) 30 | for i in range(blocksize): 31 | c_i = 1 / np.sqrt(2) if i == 0 else 1.0 32 | for n in range(blocksize): 33 | A[i, n] = ( 34 | np.sqrt(2 / blocksize) 35 | * c_i 36 | * np.cos((2 * n + 1) / (blocksize * 2) * i * np.pi) 37 | ) 38 | 39 | # set up conv layer 40 | self.A = nn.Parameter(torch.tensor(A, dtype=torch.float32), requires_grad=False) 41 | self.unfold = torch.nn.Unfold( 42 | kernel_size=blocksize, padding=0, stride=self.stride 43 | ) 44 | return 45 | 46 | def forward(self, x): 47 | """ 48 | performs 2D blockwhise DCT 49 | 50 | Parameters: 51 | x: tensor of dimension (N, 1, h, w) 52 | 53 | Return: 54 | tensor of dimension (N, k, blocksize, blocksize) 55 | where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients 56 | """ 57 | 58 | (N, C, H, W) = x.shape 59 | assert C == 1, "DCT is only implemented for a single channel" 60 | assert H >= self.blocksize, "Input too small for blocksize" 61 | assert W >= self.blocksize, "Input too small for blocksize" 62 | assert (H % self.stride == 0) and ( 63 | W % self.stride == 0 64 | ), "FFT is only for dimensions divisible by the blocksize" 65 | 66 | # unfold to blocks 67 | x = self.unfold(x) 68 | # now shape (N, blocksize**2, k) 69 | (N, _, k) = x.shape 70 | x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2) 71 | # now shape (N, #k, blocksize, blocksize) 72 | # perform DCT 73 | coeff = self.A.matmul(x).matmul(self.A.transpose(0, 1)) 74 | 75 | return coeff 76 | 77 | def inverse(self, coeff, output_shape): 78 | """ 79 | performs 2D blockwhise iDCT 80 | 81 | Parameters: 82 | coeff: tensor of dimension (N, k, blocksize, blocksize) 83 | where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block DCT coefficients 84 | output_shape: (h, w) dimensions of the reconstructed image 85 | 86 | Return: 87 | tensor of dimension (N, 1, h, w) 88 | """ 89 | if self.interleaving: 90 | raise Exception( 91 | "Inverse block DCT is not implemented for interleaving blocks!" 
92 | ) 93 | 94 | # perform iDCT 95 | x = self.A.transpose(0, 1).matmul(coeff).matmul(self.A) 96 | (N, k, _, _) = x.shape 97 | x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k) 98 | x = F.fold( 99 | x, 100 | output_size=(output_shape[-2], output_shape[-1]), 101 | kernel_size=self.blocksize, 102 | padding=0, 103 | stride=self.blocksize, 104 | ) 105 | return x 106 | -------------------------------------------------------------------------------- /metrics/metrics/watson/rfft2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.fft as fft 4 | import torch.nn.functional as F 5 | 6 | 7 | class Rfft2d(nn.Module): 8 | """ 9 | Blockwhise 2D FFT 10 | for fixed blocksize of 8x8 11 | """ 12 | 13 | def __init__(self, blocksize=8, interleaving=False): 14 | """ 15 | Parameters: 16 | """ 17 | super().__init__() # call super constructor 18 | 19 | self.blocksize = blocksize 20 | self.interleaving = interleaving 21 | if interleaving: 22 | self.stride = self.blocksize // 2 23 | else: 24 | self.stride = self.blocksize 25 | 26 | self.unfold = torch.nn.Unfold( 27 | kernel_size=self.blocksize, padding=0, stride=self.stride 28 | ) 29 | return 30 | 31 | def forward(self, x): 32 | """ 33 | performs 2D blockwhise DCT 34 | 35 | Parameters: 36 | x: tensor of dimension (N, 1, h, w) 37 | 38 | Return: 39 | tensor of dimension (N, k, b, b/2, 2) 40 | where the 2nd dimension indexes the block. Dimensions 3 and 4 are the block real FFT coefficients. 41 | The last dimension is pytorches representation of complex values 42 | """ 43 | 44 | (N, C, H, W) = x.shape 45 | assert C == 1, "FFT is only implemented for a single channel" 46 | assert H >= self.blocksize, "Input too small for blocksize" 47 | assert W >= self.blocksize, "Input too small for blocksize" 48 | assert (H % self.stride == 0) and ( 49 | W % self.stride == 0 50 | ), "FFT is only for dimensions divisible by the blocksize" 51 | 52 | # unfold to blocks 53 | x = self.unfold(x) 54 | # now shape (N, 64, k) 55 | (N, _, k) = x.shape 56 | x = x.view(-1, self.blocksize, self.blocksize, k).permute(0, 3, 1, 2) 57 | # now shape (N, #k, b, b) 58 | # perform DCT 59 | coeff = fft.rfft(x) 60 | coeff = torch.view_as_real(coeff) 61 | 62 | return coeff / self.blocksize**2 63 | 64 | def inverse(self, coeff, output_shape): 65 | """ 66 | performs 2D blockwhise inverse rFFT 67 | 68 | Parameters: 69 | output_shape: Tuple, dimensions of the outpus sample 70 | """ 71 | if self.interleaving: 72 | raise Exception( 73 | "Inverse block FFT is not implemented for interleaving blocks!" 74 | ) 75 | 76 | # perform iRFFT 77 | x = fft.irfft(coeff, dim=2, signal_sizes=(self.blocksize, self.blocksize)) 78 | (N, k, _, _) = x.shape 79 | x = x.permute(0, 2, 3, 1).view(-1, self.blocksize**2, k) 80 | x = F.fold( 81 | x, 82 | output_size=(output_shape[-2], output_shape[-1]), 83 | kernel_size=self.blocksize, 84 | padding=0, 85 | stride=self.blocksize, 86 | ) 87 | return x * (self.blocksize**2) 88 | -------------------------------------------------------------------------------- /metrics/metrics/watson/shift_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import numpy as np 3 | 4 | 5 | class ShiftWrapper(nn.Module): 6 | """ 7 | Extension for 2-dimensional inout loss functions. 8 | Shifts the inputs by up to 4 pixels. Uses replication padding. 
9 | """ 10 | 11 | def __init__(self, lossclass, args, kwargs): 12 | """ 13 | Parameters: 14 | lossclass: class of the individual loss functions 15 | trainable: bool, if True parameters of the loss are trained. 16 | args: tuple, arguments for instantiation of loss fun 17 | kwargs: dict, key word arguments for instantiation of loss fun 18 | """ 19 | super().__init__() 20 | 21 | # submodules 22 | self.add_module("loss", lossclass(*args, **kwargs)) 23 | 24 | # shift amount 25 | self.max_shift = 8 26 | 27 | # padding 28 | self.pad = nn.ReplicationPad2d(self.max_shift // 2) 29 | 30 | def forward(self, input, target): 31 | # convert color space 32 | input = self.pad(input) 33 | target = self.pad(target) 34 | 35 | shift_x = np.random.randint(self.max_shift) 36 | shift_y = np.random.randint(self.max_shift) 37 | 38 | input = input[ 39 | :, 40 | :, 41 | shift_x : -(self.max_shift - shift_x), 42 | shift_y : -(self.max_shift - shift_y), 43 | ] 44 | target = target[ 45 | :, 46 | :, 47 | shift_x : -(self.max_shift - shift_x), 48 | shift_y : -(self.max_shift - shift_y), 49 | ] 50 | 51 | return self.loss(input, target) 52 | -------------------------------------------------------------------------------- /metrics/metrics/watson/ssim.py: -------------------------------------------------------------------------------- 1 | # SSIM implementation from https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor( 10 | [ 11 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) 12 | for x in range(window_size) 13 | ] 14 | ) 15 | return gauss / gauss.sum() 16 | 17 | 18 | def create_window(window_size, channel): 19 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 20 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 21 | window = Variable( 22 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 23 | ) 24 | return window 25 | 26 | 27 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 28 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 29 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 30 | 31 | mu1_sq = mu1.pow(2) 32 | mu2_sq = mu2.pow(2) 33 | mu1_mu2 = mu1 * mu2 34 | 35 | sigma1_sq = ( 36 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 37 | ) 38 | sigma2_sq = ( 39 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 40 | ) 41 | sigma12 = ( 42 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 43 | - mu1_mu2 44 | ) 45 | 46 | C1 = 0.01**2 47 | C2 = 0.03**2 48 | 49 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 50 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 51 | ) 52 | 53 | if size_average: 54 | return ssim_map.mean() 55 | else: 56 | return ssim_map.mean(1).mean(1).mean(1) 57 | 58 | 59 | class SSIM(torch.nn.Module): 60 | def __init__(self, window_size=11, size_average=True): 61 | super(SSIM, self).__init__() 62 | self.window_size = window_size 63 | self.size_average = size_average 64 | self.channel = 1 65 | self.window = create_window(window_size, self.channel) 66 | 67 | def forward(self, img1, img2): 68 | (_, channel, _, _) = img1.size() 69 | 70 | if channel == self.channel and self.window.data.type() == img1.data.type(): 71 | window = self.window 72 | else: 73 | window = 
create_window(self.window_size, channel) 74 | 75 | if img1.is_cuda: 76 | window = window.cuda(img1.get_device()) 77 | window = window.type_as(img1) 78 | 79 | self.window = window 80 | self.channel = channel 81 | 82 | return 1 - _ssim( 83 | img1, img2, window, self.window_size, channel, self.size_average 84 | ) 85 | 86 | 87 | def ssim(img1, img2, window_size=11, size_average=True): 88 | (_, channel, _, _) = img1.size() 89 | window = create_window(window_size, channel) 90 | 91 | if img1.is_cuda: 92 | window = window.cuda(img1.get_device()) 93 | window = window.type_as(img1) 94 | 95 | return _ssim(img1, img2, window, window_size, channel, size_average) 96 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/gray_adaptive_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:db9673847841ebdfa130f9d861e6e8a27426cc4f88e5f8b702639eda7a089667 3 | size 98984 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/gray_pnet_lin_squeeze_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:09edfe8565911f6c275646e1a5a5a04a17997e8175bb79bbc10c6ca3e87e602a 3 | size 10811 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/gray_pnet_lin_vgg_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bbf8ea727bac61aaaefba7fae63a10547c03d4cd5b81a7ba05d8003a1ca96788 3 | size 7297 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/gray_watson_dct_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6401e9d3a430ac5dac691f93f1b9b61562ce13457a26b4dcb87eea39c5ea7ff4 3 | size 1489 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/gray_watson_fft_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:74fda350525e4660671bc957944d6e9773b3989effc5aa9d2b73eccd5943584e 3 | size 1304 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/gray_watson_vgg_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1ee5a769bde7e8eea181c4082a4bbfaf0161fac1cd0384de30f82bc35ba20ba2 3 | size 58871350 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/rgb_adaptive_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3b2d5d28b8468f5e1d119f1fa709199b6c645dc7f58538d939ebc27287fb95d4 3 | size 98984 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/rgb_pnet_lin_squeeze_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:7f4c84e909231e6eb253f55f8484e5895c2044178fe3316912ef4ebd50dc7b14 3 | size 10811 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/rgb_pnet_lin_vgg_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:df4bad76bf9d8bca8466f4e34adc12289dbce25c6e040cba3a5152c24cdf3682 3 | size 7297 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/rgb_watson_dct_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bde40e0bc1912060da45b4487df660450ba91fa44c031f22a144c661d99488ef 3 | size 4326 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/rgb_watson_fft_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:98386c0e0841fc346ab4c7d296788ed0af81f2ca8866a885be4405d03321a717 3 | size 3771 4 | -------------------------------------------------------------------------------- /metrics/metrics/watson/weights/rgb_watson_vgg_trial0.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e2bc2bb44d8e5a484c47b6a6a695d3771650a8fb2904e90aaedfe720fd7da796 3 | size 58871350 4 | -------------------------------------------------------------------------------- /metrics/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm.auto import tqdm 5 | from utils import to_tensor 6 | from .lpips import LPIPS 7 | from .watson import LossProvider 8 | 9 | 10 | def load_perceptual_models(metric_name, mode, device=torch.device("cuda")): 11 | assert metric_name in ["lpips", "watson"] 12 | if metric_name == "lpips": 13 | assert mode in ["vgg", "alex"] 14 | perceptual_model = LPIPS(net=mode).to(device) 15 | elif metric_name == "watson": 16 | assert mode in ["vgg", "dft", "fft"] 17 | perceptual_model = ( 18 | LossProvider() 19 | .get_loss_function( 20 | "Watson-" + mode, colorspace="RGB", pretrained=True, reduction="none" 21 | ) 22 | .to(device) 23 | ) 24 | else: 25 | assert False 26 | return perceptual_model 27 | 28 | 29 | # Compute metric between two images 30 | def compute_metric(image1, image2, perceptual_model, device=torch.device("cuda")): 31 | assert isinstance(image1, Image.Image) and isinstance(image2, Image.Image) 32 | image1_tensor = to_tensor([image1]).to(device) 33 | image2_tensor = to_tensor([image2]).to(device) 34 | return perceptual_model(image1_tensor, image2_tensor).cpu().item() 35 | 36 | 37 | # Compute LPIPS distance between two images 38 | def compute_lpips(image1, image2, mode="alex", device=torch.device("cuda")): 39 | perceptual_model = load_perceptual_models("lpips", mode, device) 40 | return compute_metric(image1, image2, perceptual_model, device) 41 | 42 | 43 | # Compute Watson distance between two images 44 | def compute_watson(image1, image2, mode="dft", device=torch.device("cuda")): 45 | perceptual_model = load_perceptual_models("watson", mode, device) 46 | return compute_metric(image1, image2, perceptual_model, device) 47 | 48 | 49 | # Compute metrics between pairs of images 50 | def 
compute_perceptual_metric_repeated( 51 | images1, 52 | images2, 53 | metric_name, 54 | mode, 55 | model, 56 | device, 57 | ): 58 | # Accept list of PIL images 59 | assert isinstance(images1, list) and isinstance(images1[0], Image.Image) 60 | assert isinstance(images2, list) and isinstance(images2[0], Image.Image) 61 | assert len(images1) == len(images2) 62 | if model is None: 63 | model = load_perceptual_models(metric_name, mode).to(device) 64 | return ( 65 | model(to_tensor(images1).to(device), to_tensor(images2).to(device)) 66 | .detach() 67 | .cpu() 68 | .numpy() 69 | .flatten() 70 | .tolist() 71 | ) 72 | 73 | 74 | # Compute LPIPS distance between pairs of images 75 | def compute_lpips_repeated( 76 | images1, 77 | images2, 78 | mode="alex", 79 | model=None, 80 | device=torch.device("cuda"), 81 | ): 82 | return compute_perceptual_metric_repeated( 83 | images1, images2, "lpips", mode, model, device 84 | ) 85 | 86 | 87 | # Compute Watson distance between pairs of images 88 | def compute_watson_repeated( 89 | images1, 90 | images2, 91 | mode="dft", 92 | model=None, 93 | device=torch.device("cuda"), 94 | ): 95 | return compute_perceptual_metric_repeated( 96 | images1, images2, "watson", mode, model, device 97 | ) 98 | -------------------------------------------------------------------------------- /metrics/prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import GPT2LMHeadModel, GPT2TokenizerFast 3 | 4 | 5 | # Load GPT-2 large model and tokenizer 6 | def load_perplexity_model_and_tokenizer(): 7 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 8 | ppl_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device) 9 | ppl_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-large") 10 | return ppl_model, ppl_tokenizer 11 | 12 | 13 | # Compute perplexity for a single prompt 14 | def compute_prompt_perplexity(prompt, models, stride=512): 15 | assert isinstance(prompt, str) 16 | assert isinstance(models, tuple) and len(models) == 2 17 | ppl_model, ppl_tokenizer = models 18 | encodings = ppl_tokenizer(prompt, return_tensors="pt") 19 | max_length = ppl_model.config.n_positions 20 | seq_len = encodings.input_ids.size(1) 21 | nlls = [] 22 | prev_end_loc = 0 23 | for begin_loc in range(0, seq_len, stride): 24 | end_loc = min(begin_loc + max_length, seq_len) 25 | trg_len = end_loc - prev_end_loc # may be different from stride on last loop 26 | input_ids = encodings.input_ids[:, begin_loc:end_loc].to( 27 | next(ppl_model.parameters()).device 28 | ) 29 | target_ids = input_ids.clone() 30 | target_ids[:, :-trg_len] = -100 31 | with torch.no_grad(): 32 | outputs = ppl_model(input_ids, labels=target_ids) 33 | neg_log_likelihood = outputs.loss 34 | nlls.append(neg_log_likelihood) 35 | prev_end_loc = end_loc 36 | if end_loc == seq_len: 37 | break 38 | ppl = torch.exp(torch.stack(nlls).mean()).item() 39 | return ppl 40 | -------------------------------------------------------------------------------- /metrics/workflow_a_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/metrics/workflow_a_small.png -------------------------------------------------------------------------------- /regeneration/__init__.py: -------------------------------------------------------------------------------- 1 | from regen import regen_diff, rinse_2xDiff, rinse_4xDiff, regen_vae 2 | 
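A brief usage sketch for the metric helpers listed above (metrics/perceptual.py and metrics/prompt.py). The import paths follow the package layout shown in this listing, the image file names are placeholders, and a CUDA device plus the bundled LPIPS weights are assumed.

from PIL import Image
from metrics.perceptual import compute_lpips_repeated
from metrics.prompt import load_perplexity_model_and_tokenizer, compute_prompt_perplexity

# Pairwise LPIPS between original images and their attacked counterparts.
originals = [Image.open("original_0.png"), Image.open("original_1.png")]
attacked = [Image.open("attacked_0.png"), Image.open("attacked_1.png")]
lpips_scores = compute_lpips_repeated(originals, attacked, mode="alex")

# Perplexity of a generation prompt under GPT-2 large (lower is more natural).
ppl_models = load_perplexity_model_and_tokenizer()
ppl = compute_prompt_perplexity("A watercolor painting of a lighthouse at dawn", ppl_models)

print(lpips_scores, ppl)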
-------------------------------------------------------------------------------- /requirements_attack.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.4.2 2 | omegaconf==2.1.1 3 | einops==0.3.0 4 | torchattacks==3.5.1 5 | transformers==4.32.1 6 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers -------------------------------------------------------------------------------- /requirements_cli.txt: -------------------------------------------------------------------------------- 1 | click>=8.1.7 2 | datasets>=2.14.6 3 | diffusers>=0.23.0 4 | GitPython>=3.1.40 5 | gradio>=4.2.0 6 | imageio>=2.32.0 7 | ipython>=8.16.1 8 | matplotlib>=3.8.1 9 | open_clip_torch>=2.23.0 10 | opencv_python>=4.8.1.78 11 | scikit_learn>=1.3.2 12 | scipy>=1.11.3 13 | scikit-image>=0.0 14 | transformers>=4.35.0 15 | onnx>=1.15.0 16 | onnxruntime-gpu>=1.16.2 17 | clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33 18 | orjson>=3.9.10 19 | huggingface-hub>=0.17.3 20 | tqdm>=4.66.1 21 | plotly>=5.18.0 22 | plotly-express>=0.4.1 23 | python-dotenv==1.0.0 24 | -------------------------------------------------------------------------------- /requirements_space.txt: -------------------------------------------------------------------------------- 1 | gitpython==3.1.40 2 | python-dotenv==1.0.0 3 | matplotlib==3.8.1 4 | numpy==1.24.1 5 | scipy==1.11.3 6 | scikit-learn==1.3.2 7 | Pillow==9.3.0 8 | pandas==2.1.3 9 | joblib==1.3.2 10 | plotly==5.18.0 11 | plotly-express==0.4.1 -------------------------------------------------------------------------------- /scripts/chmod.py: -------------------------------------------------------------------------------- 1 | import os 2 | import stat 3 | import concurrent.futures 4 | from tqdm.auto import tqdm 5 | import dotenv 6 | 7 | dotenv.load_dotenv(override=False) 8 | 9 | 10 | def change_permission(path, progress, current_user_id): 11 | try: 12 | if os.stat(path).st_uid == current_user_id: 13 | current_permissions = stat.S_IMODE(os.lstat(path).st_mode) 14 | os.chmod(path, current_permissions | stat.S_IWGRP) 15 | except Exception as e: 16 | print(f"Error changing permissions for {path}: {e}") 17 | finally: 18 | progress.update(1) 19 | progress.refresh() 20 | 21 | 22 | def walk_directory(directory): 23 | paths = [] 24 | for root, dirs, files in os.walk(directory): 25 | paths.append(root) 26 | for name in dirs + files: 27 | paths.append(os.path.join(root, name)) 28 | return paths 29 | 30 | 31 | def parallel_walk(root_dir): 32 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: 33 | all_paths = [root_dir] 34 | futures = [] 35 | progress = tqdm(desc="Walking directories", unit="dir") 36 | 37 | for first_level_dir in os.listdir(root_dir): 38 | first_level_path = os.path.join(root_dir, first_level_dir) 39 | if os.path.isdir(first_level_path): 40 | for second_level_dir in os.listdir(first_level_path): 41 | second_level_path = os.path.join(first_level_path, second_level_dir) 42 | if os.path.isdir(second_level_path): 43 | future = executor.submit(walk_directory, second_level_path) 44 | futures.append(future) 45 | else: 46 | all_paths.append(second_level_path) # Append second-level files 47 | all_paths.append(first_level_path) # Append first-level folders and files 48 | 49 | for future in concurrent.futures.as_completed(futures): 50 | all_paths.extend(future.result()) 51 | progress.update(1) 52 | 53 | return all_paths 54 | 55 | 56 | def 
chmod_parallel(paths, current_user_id): 57 | with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor, tqdm( 58 | total=len(paths), desc="Changing permissions", unit="file" 59 | ) as progress: 60 | futures = [ 61 | executor.submit(change_permission, path, progress, current_user_id) 62 | for path in paths 63 | ] 64 | concurrent.futures.wait(futures) 65 | 66 | 67 | def main(): 68 | root_dir = os.path.join(os.environ.get("DATA_DIR"), "attacked") 69 | print(f"Collecting paths in {root_dir}") 70 | paths = parallel_walk(root_dir) 71 | print(f"Changing permissions for all folders and files in {root_dir}") 72 | chmod_parallel(paths, os.getuid()) 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /scripts/status.py: -------------------------------------------------------------------------------- 1 | import click 2 | import os 3 | from PIL import Image, ImageDraw 4 | from dev import ( 5 | LIMIT, 6 | check_file_existence, 7 | parse_image_dir_path, 8 | encode_image_to_string, 9 | save_json, 10 | ) 11 | 12 | 13 | def create_placeholder_image(size=512): 14 | # Create a 512x512 image with 50% gray background 15 | image = Image.new("RGB", (size, size), (128, 128, 128)) 16 | draw = ImageDraw.Draw(image) 17 | # Define the dark red color 18 | dark_red = (139, 0, 0) 19 | # Draw two lines to form the cross 20 | # Line from top-left to bottom-right 21 | draw.line((0, 0, 511, 511), fill=dark_red, width=10) 22 | # Line from top-right to bottom-left 23 | draw.line((511, 0, 0, 511), fill=dark_red, width=10) 24 | return image 25 | 26 | 27 | def get_image_dir_thumbnails(path, sampled, limit=5000): 28 | thumbnails = [] 29 | for i in range(limit): 30 | if i in sampled: 31 | image_path = os.path.join(path, f"{i}.png") 32 | if os.path.exists(image_path): 33 | thumbnails.append(encode_image_to_string(Image.open(image_path))) 34 | else: 35 | thumbnails.append(encode_image_to_string(create_placeholder_image())) 36 | else: 37 | thumbnails.append(None) 38 | return thumbnails 39 | 40 | 41 | @click.command() 42 | @click.option( 43 | "--path", "-p", type=str, default=os.getcwd(), help="Path to image directory" 44 | ) 45 | @click.option("--dry", "-d", is_flag=True, default=False, help="Dry run") 46 | @click.option("--quiet", "-q", is_flag=True, default=False, help="Quiet mode") 47 | def main(path, dry, quiet, limit=LIMIT): 48 | dataset_name, _, _, _ = parse_image_dir_path(path, quiet=quiet) 49 | existences = check_file_existence(path, name_pattern="{}.png", limit=limit) 50 | if not quiet: 51 | print(f"Found {sum(existences)} images out of {limit}") 52 | thumbnails = get_image_dir_thumbnails(path, sampled=[0, 1, 10, 100], limit=limit) 53 | data = {} 54 | for i in range(limit): 55 | data[str(i)] = {"exist": existences[i], "thumbnail": thumbnails[i]} 56 | json_path = os.path.join( 57 | os.environ.get("RESULT_DIR"), 58 | dataset_name, 59 | f"{str(path).split('/')[-1]}-status.json", 60 | ) 61 | save_json(data, json_path) 62 | if not quiet: 63 | print(f"Image directory status saved to {json_path}") 64 | 65 | 66 | if __name__ == "__main__": 67 | main() 68 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # Reading requirements from 'requirements_cli.txt' 4 | with open("requirements_cli.txt") as f: 5 | requirements = f.read().splitlines() 6 | 7 | setup( 8 | 
name="wmbench", 9 | version="0.2.1", 10 | packages=find_packages(), 11 | py_modules=["cli"], 12 | install_requires=requirements, 13 | entry_points={ 14 | "console_scripts": ["wmbench=cli:cli"] # Pointing to the cli function in cli.py 15 | }, 16 | # Other metadata 17 | ) 18 | -------------------------------------------------------------------------------- /shell_scripts/install_dependencies.sh: -------------------------------------------------------------------------------- 1 | # Plase run this script in the root directory of this repo 2 | 3 | # Virtual environment 4 | python3 -m venv venv 5 | source venv/bin/activate 6 | 7 | # Dependencies 8 | # Install and upgrade jupyter`` 9 | pip install --upgrade pip ipython jupyter ipywidgets python-dotenv 10 | # Install dependences (on CUDA 11.8) 11 | # PyTorch 2.1.0 12 | pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 13 | # Huggingface libraries 14 | pip install transformers diffusers 'datasets[vision]' ftfy 15 | pip install -U xformers --index-url https://download.pytorch.org/whl/cu118 16 | # Other machine learning libraries 17 | pip install onnx onnxruntime-gpu torchmetrics open_clip_torch torchattacks scikit-learn scikit-image pandas 18 | # Data processing libraries 19 | pip install pycocotools matplotlib imageio opencv-python 20 | # Metric libraries 21 | pip install git+https://github.com/openai/CLIP.git 22 | # Parallel libraries 23 | pip install accelerate deepspeed 24 | # HF space and gradio libraries 25 | pip install huggingface-hub gitpython gradio==4.3.0 plotly plotly-express wordcloud 26 | # Other libraries 27 | 28 | # Fix CUDNN issue for libnvrtc.so, see https://stackoverflow.com/questions/76216778/userwarning-applied-workaround-for-cudnn-issue-install-nvrtc-so 29 | cd venv/lib/python3.10/site-packages/torch/lib 30 | ln -s libnvrtc-*.so.11.2 libnvrtc.so 31 | cd - 32 | 33 | # Fix vscode jupyter issue, see https://github.com/microsoft/vscode-jupyter/issues/14618 34 | pip install ipython==8.16.1 -------------------------------------------------------------------------------- /static/images/2d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/2d.jpg -------------------------------------------------------------------------------- /static/images/2d_ident.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/2d_ident.jpg -------------------------------------------------------------------------------- /static/images/2d_tree_ident.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/2d_tree_ident.jpg -------------------------------------------------------------------------------- /static/images/2x_regen-100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/2x_regen-100.jpg -------------------------------------------------------------------------------- /static/images/2x_regen-20.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/2x_regen-20.jpg -------------------------------------------------------------------------------- /static/images/4x_regen-10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/4x_regen-10.jpg -------------------------------------------------------------------------------- /static/images/4x_regen-50.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/4x_regen-50.jpg -------------------------------------------------------------------------------- /static/images/4x_regen_kl_vae-16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/4x_regen_kl_vae-16.jpg -------------------------------------------------------------------------------- /static/images/4x_regen_kl_vae-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/4x_regen_kl_vae-4.jpg -------------------------------------------------------------------------------- /static/images/adv_cls_wm1_wm2_0.01_50_warm-2-tree_ring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_cls_wm1_wm2_0.01_50_warm-2-tree_ring.jpg -------------------------------------------------------------------------------- /static/images/adv_cls_wm1_wm2_0.01_50_warm-8-tree_ring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_cls_wm1_wm2_0.01_50_warm-8-tree_ring.jpg -------------------------------------------------------------------------------- /static/images/adv_emb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_clip_untg_alphaRatio_0.05_step_200-16-tree_ring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_clip_untg_alphaRatio_0.05_step_200-16-tree_ring.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_clip_untg_alphaRatio_0.05_step_200-2-tree_ring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_clip_untg_alphaRatio_0.05_step_200-2-tree_ring.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_clip_untg_alphaRatio_0.05_step_200-8-tree_ring.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_clip_untg_alphaRatio_0.05_step_200-8-tree_ring.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_coco.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_coco.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_diff.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_diff.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_same_vae_untg-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_same_vae_untg-2.jpg -------------------------------------------------------------------------------- /static/images/adv_emb_same_vae_untg-8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_emb_same_vae_untg-8.jpg -------------------------------------------------------------------------------- /static/images/adv_spoof.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_spoof.jpg -------------------------------------------------------------------------------- /static/images/adv_su.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/adv_su.jpg -------------------------------------------------------------------------------- /static/images/all_fig_coco_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/all_fig_coco_1.jpg -------------------------------------------------------------------------------- /static/images/all_fig_coco_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/all_fig_coco_2.jpg -------------------------------------------------------------------------------- /static/images/all_fig_dalle_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/all_fig_dalle_1.jpg -------------------------------------------------------------------------------- /static/images/all_fig_dalle_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/all_fig_dalle_2.jpg -------------------------------------------------------------------------------- /static/images/all_fig_diff_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/all_fig_diff_1.jpg -------------------------------------------------------------------------------- /static/images/all_fig_diff_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/all_fig_diff_2.jpg -------------------------------------------------------------------------------- /static/images/bench_watermarks_detect 2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/bench_watermarks_detect 2.jpg -------------------------------------------------------------------------------- /static/images/bench_watermarks_detect.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/bench_watermarks_detect.jpg -------------------------------------------------------------------------------- /static/images/bench_watermarks_ident 2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/bench_watermarks_ident 2.jpg -------------------------------------------------------------------------------- /static/images/bench_watermarks_ident.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/bench_watermarks_ident.jpg -------------------------------------------------------------------------------- /static/images/carousel1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/carousel1.jpg -------------------------------------------------------------------------------- /static/images/carousel2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/carousel2.jpg -------------------------------------------------------------------------------- /static/images/carousel3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/carousel3.jpg -------------------------------------------------------------------------------- /static/images/carousel4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/carousel4.jpg -------------------------------------------------------------------------------- /static/images/dataset_dalle3_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dataset_dalle3_examples.jpg 
-------------------------------------------------------------------------------- /static/images/dataset_dalle3_wordcloud.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dataset_dalle3_wordcloud.jpg -------------------------------------------------------------------------------- /static/images/dataset_diffusiondb_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dataset_diffusiondb_examples.jpg -------------------------------------------------------------------------------- /static/images/dataset_diffusiondb_wordcloud.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dataset_diffusiondb_wordcloud.jpg -------------------------------------------------------------------------------- /static/images/dataset_mscoco_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dataset_mscoco_examples.jpg -------------------------------------------------------------------------------- /static/images/dataset_mscoco_wordcloud.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dataset_mscoco_wordcloud.jpg -------------------------------------------------------------------------------- /static/images/dist_com1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dist_com1.jpg -------------------------------------------------------------------------------- /static/images/dist_com2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/dist_com2.jpg -------------------------------------------------------------------------------- /static/images/distcom-deg-0.15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/distcom-deg-0.15.jpg -------------------------------------------------------------------------------- /static/images/distcom-geo-0.15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/distcom-geo-0.15.jpg -------------------------------------------------------------------------------- /static/images/distcom-photo-0.15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/distcom-photo-0.15.jpg -------------------------------------------------------------------------------- /static/images/distortion_combo_all-0.04.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/distortion_combo_all-0.04.jpg -------------------------------------------------------------------------------- /static/images/distortion_combo_all-0.2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/distortion_combo_all-0.2.jpg -------------------------------------------------------------------------------- /static/images/example_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/example_1.jpg -------------------------------------------------------------------------------- /static/images/example_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/example_2.jpg -------------------------------------------------------------------------------- /static/images/example_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/example_3.jpg -------------------------------------------------------------------------------- /static/images/example_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/example_4.jpg -------------------------------------------------------------------------------- /static/images/example_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/example_5.jpg -------------------------------------------------------------------------------- /static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/favicon.ico -------------------------------------------------------------------------------- /static/images/illu_adv_real_wm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/illu_adv_real_wm.jpg -------------------------------------------------------------------------------- /static/images/illu_adv_unwm_wm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/illu_adv_unwm_wm.jpg -------------------------------------------------------------------------------- /static/images/illu_adv_wm1_wm2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/illu_adv_wm1_wm2.jpg -------------------------------------------------------------------------------- /static/images/legend.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/legend.jpg -------------------------------------------------------------------------------- /static/images/no_attack.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/no_attack.jpg -------------------------------------------------------------------------------- /static/images/no_watermark.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/no_watermark.jpg -------------------------------------------------------------------------------- /static/images/problem.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/problem.gif -------------------------------------------------------------------------------- /static/images/quality_metric_cdf_normalize_range.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/quality_metric_cdf_normalize_range.jpg -------------------------------------------------------------------------------- /static/images/radar_iden_100.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/radar_iden_100.jpg -------------------------------------------------------------------------------- /static/images/radar_iden_1000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/radar_iden_1000.jpg -------------------------------------------------------------------------------- /static/images/radar_iden_1000000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/radar_iden_1000000.jpg -------------------------------------------------------------------------------- /static/images/radar_plot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/radar_plot.jpg -------------------------------------------------------------------------------- /static/images/regen-200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen-200.jpg -------------------------------------------------------------------------------- /static/images/regen-40.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen-40.jpg -------------------------------------------------------------------------------- 
/static/images/regen_coco_clip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen_coco_clip.jpg -------------------------------------------------------------------------------- /static/images/regen_coco_psnr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen_coco_psnr.jpg -------------------------------------------------------------------------------- /static/images/regen_diff_clip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen_diff_clip.jpg -------------------------------------------------------------------------------- /static/images/regen_diff_psnr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen_diff_psnr.jpg -------------------------------------------------------------------------------- /static/images/regen_vae.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/regen_vae.jpg -------------------------------------------------------------------------------- /static/images/spec_adv_unwm_wm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/spec_adv_unwm_wm.jpg -------------------------------------------------------------------------------- /static/images/spec_adv_wm1_wm2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/spec_adv_wm1_wm2.jpg -------------------------------------------------------------------------------- /static/images/tree-ring-heatmap-old.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/tree-ring-heatmap-old.jpg -------------------------------------------------------------------------------- /static/images/unattacked-tree_ring.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/unattacked-tree_ring.jpg -------------------------------------------------------------------------------- /static/images/violin.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/violin.jpg -------------------------------------------------------------------------------- /static/images/waves.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/waves.jpg 
-------------------------------------------------------------------------------- /static/images/waves_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/waves_small.png -------------------------------------------------------------------------------- /static/images/workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/workflow.jpg -------------------------------------------------------------------------------- /static/images/workflow_a.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/workflow_a.gif -------------------------------------------------------------------------------- /static/images/workflow_a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/workflow_a.jpg -------------------------------------------------------------------------------- /static/images/workflow_b.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/workflow_b.gif -------------------------------------------------------------------------------- /static/images/workflow_b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/static/images/workflow_b.jpg -------------------------------------------------------------------------------- /tree_ring/__init__.py: -------------------------------------------------------------------------------- 1 | from .guided_diffusion import ( 2 | generate_guided_tree_ring_message, 3 | generate_guided_tree_ring_key, 4 | guided_ddim_sample_with_tree_ring, 5 | detect_guided_tree_ring, 6 | ) 7 | from .stable_diffusion import InversableStableDiffusionPipeline 8 | from .data_utils import load_tree_ring_guided 9 | -------------------------------------------------------------------------------- /tree_ring/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import ImageFolder 4 | from utils import to_tensor 5 | 6 | 7 | class MultiLabelImageFolder(ImageFolder): 8 | def __init__(self, root, transform=None, target_transform=None): 9 | super(MultiLabelImageFolder, self).__init__( 10 | root, transform=transform, target_transform=target_transform 11 | ) 12 | 13 | # Detect the number of labels 14 | example_dir = os.listdir(root)[0] 15 | num_labels = example_dir.count("_") + 1 16 | 17 | label_sets = [set() for _ in range(num_labels)] 18 | 19 | # Collect all unique labels for each position 20 | for d in os.listdir(root): 21 | labels = d.split("_") 22 | for i, label in enumerate(labels): 23 | label_sets[i].add(label) 24 | 25 | # Create label mappings for each label set 26 | self.label_to_idx = [ 27 | {label: idx for idx, label in enumerate(sorted(label_set))} 28 | for label_set in label_sets 29 | ] 30 | 31 | def __getitem__(self, index): 32 | path, _ = 
self.samples[index] 33 | sample = self.loader(path) 34 | 35 | # Extract the multi-labels from the folder name 36 | dirname = os.path.basename(os.path.dirname(path)) 37 | labels = dirname.split("_") 38 | 39 | label_idxs = [ 40 | label_map[label] for label, label_map in zip(labels, self.label_to_idx) 41 | ] 42 | 43 | if self.transform is not None: 44 | sample = self.transform(sample) 45 | 46 | if self.target_transform is not None: 47 | label_idxs = self.target_transform(label_idxs) 48 | 49 | return sample, tuple(label_idxs) 50 | 51 | 52 | def load_tree_ring_guided( 53 | image_size, 54 | dataset_template, 55 | num_key_seeds, 56 | num_message_seeds, 57 | convert_to_tensor=True, 58 | norm_type="naive", 59 | ): 60 | assert image_size in [64, 256] 61 | assert dataset_template in ["Tiny-ImageNet", "Imagenette"] 62 | # Load WordNet IDs and class names 63 | wnid_to_words = {} 64 | with open("./datasets/tiny-imagenet-200/words.txt", "r") as f: 65 | for line in f.readlines(): 66 | wnid, words = line.strip().split("\t") 67 | wnid_to_words[wnid] = words 68 | # Load dataset 69 | data_dir = f"./datasets/tree_ring_guided_{image_size}_{dataset_template.lower()}_{num_key_seeds}k_{num_message_seeds}m" 70 | 71 | dataset = MultiLabelImageFolder( 72 | f"{data_dir}/train", 73 | lambda x: to_tensor([x], norm_type=norm_type) if convert_to_tensor else x, 74 | ) 75 | # ImageNet class names 76 | class_names = [ 77 | wnid_to_words[class_names.split("_")[0]] for class_names in dataset.classes 78 | ] 79 | # Load keys and messages 80 | keys = torch.load(f"{data_dir}/keys.pt") 81 | messages = torch.load(f"{data_dir}/messages.pt") 82 | # Check sizes 83 | assert len(set(label[0] for _, label in dataset)) == len(class_names) 84 | assert len(set(label[1] for _, label in dataset)) == len(keys) 85 | assert len(set(label[2] for _, label in dataset)) == len(messages) 86 | 87 | return dataset, class_names, keys, messages 88 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_utils import ( 2 | normalize_tensor, 3 | unnormalize_tensor, 4 | to_tensor, 5 | to_pil, 6 | renormalize_tensor, 7 | ) 8 | from .data_utils import ( 9 | get_imagenet_class_names, 10 | get_imagenet_wnids, 11 | load_imagenet_subset, 12 | sample_train_and_test_sets, 13 | load_imagenet_guided, 14 | sample_images_by_label_cond, 15 | sample_images_by_label_set, 16 | ) 17 | from .vis_utils import ( 18 | visualize_image_grid, 19 | visualize_image_list, 20 | visualize_supervised_dataset, 21 | save_figure_to_file, 22 | save_figure_to_buffer, 23 | save_figure_to_pil, 24 | concatenate_images, 25 | make_gif, 26 | ) 27 | from .exp_utils import set_random_seed 28 | from .io_utils import ( 29 | tuples_to_lists, 30 | lists_to_tuples, 31 | format_mean_and_std, 32 | format_mean_and_std_list, 33 | ) 34 | -------------------------------------------------------------------------------- /utils/exp_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | 6 | # Set random seed for reproducibility 7 | def set_random_seed(seed=0): 8 | torch.manual_seed(seed + 0) 9 | torch.cuda.manual_seed(seed + 1) 10 | torch.cuda.manual_seed_all(seed + 2) 11 | np.random.seed(seed + 3) 12 | torch.cuda.manual_seed_all(seed + 4) 13 | random.seed(seed + 5) 14 | -------------------------------------------------------------------------------- /utils/image_utils.py: 
-------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from torchvision import transforms 4 | 5 | 6 | # Normalize image tensors 7 | def normalize_tensor(images, norm_type): 8 | assert norm_type in ["imagenet", "naive"] 9 | # Two possible normalization conventions 10 | if norm_type == "imagenet": 11 | mean = [0.485, 0.456, 0.406] 12 | std = [0.229, 0.224, 0.225] 13 | normalize = transforms.Normalize(mean, std) 14 | elif norm_type == "naive": 15 | mean = [0.5, 0.5, 0.5] 16 | std = [0.5, 0.5, 0.5] 17 | normalize = transforms.Normalize(mean, std) 18 | else: 19 | assert False 20 | return torch.stack([normalize(image) for image in images]) 21 | 22 | 23 | # Unnormalize image tensors 24 | def unnormalize_tensor(images, norm_type): 25 | assert norm_type in ["imagenet", "naive"] 26 | # Two possible normalization conventions 27 | if norm_type == "imagenet": 28 | mean = [0.485, 0.456, 0.406] 29 | std = [0.229, 0.224, 0.225] 30 | unnormalize = transforms.Normalize( 31 | (-mean[0] / std[0], -mean[1] / std[1], -mean[2] / std[2]), 32 | (1 / std[0], 1 / std[1], 1 / std[2]), 33 | ) 34 | elif norm_type == "naive": 35 | mean = [0.5, 0.5, 0.5] 36 | std = [0.5, 0.5, 0.5] 37 | unnormalize = transforms.Normalize( 38 | (-mean[0] / std[0], -mean[1] / std[1], -mean[2] / std[2]), 39 | (1 / std[0], 1 / std[1], 1 / std[2]), 40 | ) 41 | else: 42 | assert False 43 | return torch.stack([unnormalize(image) for image in images]) 44 | 45 | 46 | # Convert PIL images to tensors and normalize 47 | def to_tensor(images, norm_type="naive"): 48 | assert isinstance(images, list) and all( 49 | [isinstance(image, Image.Image) for image in images] 50 | ) 51 | images = torch.stack([transforms.ToTensor()(image) for image in images]) 52 | if norm_type is not None: 53 | images = normalize_tensor(images, norm_type) 54 | return images 55 | 56 | 57 | # Unnormalize tensors and convert to PIL images 58 | def to_pil(images, norm_type="naive"): 59 | assert isinstance(images, torch.Tensor) 60 | if norm_type is not None: 61 | images = unnormalize_tensor(images, norm_type).clamp(0, 1) 62 | return [transforms.ToPILImage()(image) for image in images.cpu()] 63 | 64 | 65 | # Renormalize image tensors 66 | def renormalize_tensor(images, in_norm_type=None, out_norm_type=None): 67 | assert in_norm_type in ["imagenet", "naive"] 68 | assert out_norm_type in ["imagenet", "naive"] 69 | 70 | # First unnormalize the tensor using the input normalization type 71 | images = unnormalize_tensor(images, in_norm_type) 72 | 73 | # Then normalize the tensor using the output normalization type 74 | images = normalize_tensor(images, out_norm_type) 75 | 76 | return images 77 | -------------------------------------------------------------------------------- /utils/io_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | # Convert a list of tuples to a tuple of lists 5 | def tuples_to_lists(data): 6 | # Ensure that all tuples in the list have the same length 7 | assert all(len(t) == len(data[0]) for t in data) 8 | # Unzips the list of tuples and converts the output to a tuple of lists 9 | return tuple(map(list, zip(*data))) 10 | 11 | 12 | # Convert a tuple of lists to a list of tuples 13 | def lists_to_tuples(data): 14 | # Ensure that all lists in the tuple have the same length 15 | assert all(len(lst) == len(data[0]) for lst in data) 16 | # Zips the tuple of lists and converts the output to a list of tuples 17 | return list(zip(*data)) 18 
| 19 | 20 | # Get the order of magnitude of the largest value in a list of values 21 | def get_max_order_of_magnitude(values): 22 | max_order_of_magnitude = -math.inf 23 | for value in values: 24 | order_of_magnitude = math.floor(math.log10(value)) 25 | if order_of_magnitude > max_order_of_magnitude: 26 | max_order_of_magnitude = order_of_magnitude 27 | return max_order_of_magnitude 28 | 29 | 30 | # Format a mean and std into a string 31 | def format_mean_and_std(mean, std, order_of_magnitude=None, style="latex"): 32 | assert style in ["ascii", "unicode", "latex"] 33 | if style == "ascii": 34 | pm_sign_func = lambda x, y: x + " +/- " + y 35 | exp_sign_func = lambda x: "* E" + str(x) 36 | elif style == "unicode": 37 | pm_sign_func = lambda x, y: x + " ± " + y 38 | exp_sign_func = lambda x: "x 10^" + str(x) 39 | else: 40 | pm_sign_func = lambda x, y: "$" + x + " \\pm " + y + "$" 41 | exp_sign_func = lambda x: "$\\times 10^{" + str(x) + "}$" 42 | 43 | # Get the order of magnitude of the std if no shared one is given 44 | if order_of_magnitude is None: 45 | order_of_magnitude = get_max_order_of_magnitude([std]) 46 | 47 | # Rescale the std by the (possibly shared) order of magnitude 48 | adjusted_std = std / (10**order_of_magnitude) 49 | coef_std = f"{adjusted_std:.1f}" 50 | 51 | # Rescale the mean by the same order of magnitude so mean and std share one exponent 52 | adjusted_mean = mean / (10**order_of_magnitude) 53 | coef_mean = f"{adjusted_mean:.1f}" 54 | 55 | return f"({pm_sign_func(coef_mean, coef_std)}) {exp_sign_func(order_of_magnitude)}" 56 | 57 | 58 | # Format a list of means and stds into a list of strings sharing one order of magnitude, so the entries line up 59 | def format_mean_and_std_list(means, stds, order_of_magnitude=None, style="latex"): 60 | assert len(means) == len(stds) 61 | if order_of_magnitude is None: 62 | order_of_magnitude = get_max_order_of_magnitude(stds) 63 | fmt_strings = [] 64 | for mean, std in zip(means, stds): 65 | fmt_strings.append(format_mean_and_std(mean, std, order_of_magnitude, style)) 66 | return fmt_strings 67 | -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/umd-huang-lab/WAVES/1477635d306d6f8c77a588e6441bd4ba301a1f7a/utils/plot_utils.py --------------------------------------------------------------------------------
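Editor's usage sketch (not part of the repository) for the formatting helpers in utils/io_utils.py above. The numbers are invented, and the import assumes the repository root is on PYTHONPATH with its dependencies installed (importing the utils package also pulls in torch and torchvision via utils/__init__.py).

from utils.io_utils import format_mean_and_std, format_mean_and_std_list

# Single value, unicode style: prints "(3.1 ± 2.4) x 10^-3".
print(format_mean_and_std(0.00314, 0.00243, style="unicode"))

# A list of results shares one order of magnitude, derived from the largest std,
# so the LaTeX strings line up when pasted into a table column.
means = [0.0121, 0.0087, 0.0153]
stds = [0.0031, 0.0012, 0.0009]
for cell in format_mean_and_std_list(means, stds, style="latex"):
    print(cell)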