├── .gitignore ├── LICENSE ├── README.md ├── assets ├── cmmd.png ├── human_eval.gif └── jumpy_loss.png ├── compare-two-images.py ├── eval.py ├── image_eval ├── __init__.py ├── encoders.py ├── evaluators.py ├── improved_aesthetic_predictor.py ├── local_ab_test.py ├── model_utils.py └── pairwise_evaluators.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── assets ├── fortune_teller.png ├── julie.jpeg ├── julie_gn_mean0_sigma100.jpeg └── julie_gn_mean0_sigma50.jpeg └── evaluators_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info* 3 | build/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Storia AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Evaluation for Text-to-Image Models 2 | 3 | ## What is this? 4 | 5 | **TL;DR**: A Python library providing utilities to evaluate text-to-image (T2I) models. It complements established benchmarks like [HEIM](https://crfm.stanford.edu/helm/heim/latest/) by focusing on metrics that help developers fine-tune T2I models for a particular style or concept. 6 | 7 | Running it is as easy as: 8 | ``` 9 | pip install image-eval 10 | image_eval -g /path/to/generated/images -r /path/to/reference/images -p /path/to/prompts.json -m all 11 | ``` 12 | See the [installation section](#installation) for more detailed instructions. 13 | 14 | ## Motivation 15 | 16 | Image quality is subjective, and evaluating text-to-image (T2I) models is hard. But we can't make progress without being able to *measure* progress. We need standardized and robust tooling. 17 | 18 | Training T2I models is particularly difficult because there are no metrics to inform you whether your model is converging. For instance, this is what a typical training loss looks like: 19 | 20 | ![Typical training loss](assets/jumpy_loss.png) 21 | 22 | The goal of this repo is to bring back measurability and help you make informed decisions when building T2I models. For instance, we discovered that using [CMMD](https://arxiv.org/abs/2401.09603) as a validation metric during training on as little as 50 images can help you gauge how much progress your model is making. 
The plot shows the distance between a reference set and the generated set for various checkpoints: 23 | 24 | 25 | 26 | Read more about our discoveries in the [Tips and Tricks](#tips-and-tricks) section, which we will update as we learn more. 27 | 28 | ## Evaluation Metrics 29 | 30 | ### Categories 31 | We use the following categories for our metrics, inspired by [LyCORIS](https://arxiv.org/abs/2309.14859): 32 | 33 | 1. **Fidelity**: the extent to which generated images adhere to the target concept. 34 | 2. **Pairwise similarity** between two images; in contrast, fidelity makes bulk comparisons between two datasets. 35 | 3. **Controllability**: the model’s ability to generate images that align well with text prompts. 36 | 4. **Diversity**: the variety of images that are produced from a single or a set of prompts. 37 | 5. **Image quality**: the visual appeal of the generated images (naturalness, absence of artifacts or deformations). 38 | 39 | We dared to list these aspects in the order in which we believe they can be meaningfully measured. Measuring the *fidelity* of one dataset to another is a much better-defined problem than measuring a universal and elusive *image quality*. 40 | 41 | ### Metrics 42 | Here are the metrics we currently support: 43 | 44 | | Metric name | Category | Source | 45 | |------------ | -------- | ----------- | 46 | | `centroid_similarity` | fidelity | ours | 47 | | `cmmd` | fidelity | [paper](https://arxiv.org/abs/2401.09603) | 48 | | `lpips` | pairwise similarity | [paper](https://arxiv.org/pdf/1801.03924) | 49 | | `multi_ssim` | pairwise similarity | [paper](https://ieeexplore.ieee.org/document/1292216) | 50 | | `psnr` | pairwise similarity | [paper](https://ieeexplore.ieee.org/document/1163711) | 51 | | `uiqui` | pairwise similarity | [paper](https://ieeexplore.ieee.org/document/1284395) | 52 | | `clip_score` | controllability | [paper](https://arxiv.org/abs/2104.08718) | 53 | | `image_reward` | controllability | [paper](https://arxiv.org/pdf/2304.05977.pdf) | 54 | | `human_preference_score` | controllability | [repo](https://tgxs002.github.io/align_sd_web/) | 55 | | `vendi_score` | diversity | [paper](https://arxiv.org/abs/2210.02410) | 56 | | `fid` | image quality | [paper](https://arxiv.org/pdf/1706.08500) | 57 | | `inception_score` | image quality | [paper](https://arxiv.org/abs/1606.03498) | 58 | | `aesthetic_predictor` | image quality | [repo](https://github.com/christophschuhmann/improved-aesthetic-predictor) | 59 | 60 | ### Encoders 61 | Some of the metrics above rely on image embeddings in a modular way -- even though they were originally published using CLIP embeddings, we noticed that swapping embeddings might lead to better metrics in certain cases. We allow you to mix and match the metrics above with the following: 62 | 63 | 1. **[CLIP](https://arxiv.org/abs/2103.00020)** is by far the most popular encoder. It was used by the original Stable Diffusion model. 64 | 2. **[DINOv2](https://arxiv.org/abs/2304.07193)**. Compared to CLIP, which used text-guided pretraining (aligning images against captions), DINOv2 used self-supervised learning on images alone. Its training objective maximizes agreement between different patches within the same image. It was trained on a dataset of 142M automatically curated images. 65 | 3. **[ConvNeXt V2](https://arxiv.org/abs/2301.00808)**. Similarly to DINOv2, ConvNeXt V2 did not use text-guided pretraining. It was trained on an image dataset to recover masked patches. 
In contrast to DINOv2 (which uses a ViT), ConvNeXt V2 uses a convolutional architecture. ConvNeXtV2 is the successor of [MAE](https://arxiv.org/abs/2111.06377) (Masked Auto Encoder) embeddings. 66 | 4. **[InsightFace](https://insightface.ai/)** is particularly good at encoding human faces, and is very effective when fine-tuning T2I models for headshots. 67 | 68 | ## GUI 69 | We also provide a simple and ready-to-use [Streamlit](https://streamlit.io/) interface for performing human evaluation of model outputs on your local machine. We recommend using this when the automated metrics are just not discriminative enough to help you decide between two checkpoints. 70 | 71 | ![Human eval](assets/human_eval.gif) 72 | 73 | ## Installation 74 | 75 | This library has been tested on Python 3.10.9. To install: 76 | ``` 77 | pip install image-eval 78 | ``` 79 | 80 | Optionally, if you have a CUDA-enabled device, install the [version of PyTorch](https://pytorch.org/get-started/previous-versions/) that matches your CUDA version. For CUDA 11.3, that might look like: 81 | ``` 82 | pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 83 | ``` 84 | 85 | ## Usage 86 | 87 | There are two ways to interact with the `image-eval` library: either through the CLI or through the API. 88 | 89 | ### CLI for automated evaluation 90 | 91 | Once you have installed the library, you can invoke it from your terminal via the `image_eval` command. The full list of flags is in [eval.py](eval.py), but here are the most important ones: 92 | 93 | - `-g` should point to a folder of generated images 94 | - `-r` should point to a folder of reference images 95 | - `-p` (needed for controllability metrics only) should point to a `.json` file that stores `image_filename: prompt` pairs, for instance: 96 | ``` 97 | { 98 | "image_1.jpg": "prompt for image 1", 99 | "image_2.jpg": "prompt for image 2", 100 | ... 101 | } 102 | ``` 103 | - `-m` should specify the desired metrics; it can be `all`, a certain category (e.g. `fidelity`), or a specific metric (e.g. `centroid_similarity`). 104 | 105 | For example, to calculate the fidelity of a generated dataset to some reference images, you would run: 106 | ``` 107 | image_eval -m fidelity -g /path/to/generated/images -r /path/to/reference/images 108 | ``` 109 | The result will look like this: 110 | ``` 111 | | Metric Name | Value | 112 | |---------------------------------+----------| 113 | | centroid_similarity_clip | 0.844501 | 114 | | centroid_similarity_dino_v2 | 0.573843 | 115 | | centroid_similarity_convnext_v2 | 0.606375 | 116 | | centroid_similarity_insightface | 0.488649 | 117 | | cmmd_clip | 0.162164 | 118 | | cmmd_dino_v2 | 0.1689 | 119 | | cmmd_convnext_v2 | 0.187492 | 120 | | cmmd_insightface | 0.169578 | 121 | ``` 122 | 123 | ### CLI for human evaluation 124 | 125 | To launch the human evaluation interface, run: 126 | ``` 127 | image_eval --local-human-eval --model-predictions-json /path/to/model_comparisons.json 128 | ``` 129 | Here `model_comparisons.json` is a JSON file with the following format: 130 | ``` 131 | [ 132 | { 133 | "model_1": "path to image 1 from model 1", 134 | "model_2": "path to image 1 from model 2", 135 | "prompt": "prompt for image 1" 136 | }, 137 | { 138 | "model_1": "path to image 2 from model 1", 139 | "model_2": "path to image 2 from model 2", 140 | "prompt": "prompt for image 2" 141 | }, 142 | ... 
143 | ] 144 | ``` 145 | 146 | where `model_1` and `model_2` are the keys for the paths to image outputs for the respective models. Our library expects the **keys to match these values exactly**. 147 | 148 | An interface should launch in your browser at `http://localhost:8501`. 149 | 150 | NOTE: When you click `Compute Model Wins`, a local file named `scores.json` will be created in the directory from which you launched the CLI. 151 | 152 | ### Programmatic 153 | 154 | You can also interact with the library through the API directly. For example, to invoke the `clip_score` metric, you could do the following: 155 | ``` 156 | import numpy as np 157 | from image_eval.evaluators import CLIPScoreEvaluator 158 | evaluator = CLIPScoreEvaluator(device="cpu") # or "cuda" if you have a GPU-enabled device 159 | images = [np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8) for _ in range(10)] # list of 10 random images 160 | prompts = ["random prompt"] * 10 # one prompt per image 161 | evaluator.evaluate(images, prompts) 162 | ``` 163 | 164 | ## Tips and Tricks 165 | In this section, we will share our tips on how to use existing metrics to bring some rigor to the art of fine-tuning T2I models. Please don't take anything as an absolute truth. If we knew things for sure, we would be writing a paper instead of a GitHub README. Suggestions and constructive feedback are more than welcome! 166 | 167 | [TODO] 168 | 169 | ## Contributing 170 | 171 | We welcome any and all contributions to this library, as well as discussions on how we can make the art of training T2I models more scientific. 172 | 173 | To add a new **metric**, all you need to do is create a new class that inherits from `BaseEvaluator` and implements the `evaluate` method. For examples of how our current metrics implement this contract, see `evaluators.py`; a minimal sketch is also included below. 174 | 175 | To add a new **encoder**, simply implement the `BaseEncoder` interface (see `encoders.py`). 
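For illustration, here is a minimal sketch of what a new metric could look like. The `MeanBrightnessEvaluator` class and its `mean_brightness` key are made-up names for this example and are not part of the library; to expose a new metric through the CLI, you would also register it in the `METRIC_NAME_TO_EVALUATOR` dictionary in [eval.py](eval.py).

```
import numpy as np
from PIL import Image
from typing import Dict

from image_eval.evaluators import BaseEvaluator, EvaluatorType


class MeanBrightnessEvaluator(BaseEvaluator):
    """Toy image-quality metric: average pixel brightness of the generated set (example only)."""
    TYPE = EvaluatorType.IMAGE_QUALITY
    HIGHER_IS_BETTER = True

    def __init__(self, device: str):
        # This toy metric loads no model, so `device` is effectively unused.
        super().__init__(device)

    def evaluate(self, generated_images: list[Image.Image], **ignored_kwargs) -> Dict[str, float]:
        # Average grayscale intensity (0-255) over all generated images.
        brightness = [np.asarray(img.convert("L")).mean() for img in generated_images]
        return {"mean_brightness": float(np.mean(brightness))}


evaluator = MeanBrightnessEvaluator(device="cpu")
print(evaluator.evaluate([Image.new("RGB", (64, 64), "white")]))  # {'mean_brightness': 255.0}
```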
176 | 177 | ## Other Resources 178 | Here are other notable resources for evaluating T2I models: 179 | - [HEIM](https://crfm.stanford.edu/helm/heim/latest/) 180 | - [T2I-Compbench](https://github.com/Karine-Huang/T2I-CompBench) 181 | -------------------------------------------------------------------------------- /assets/cmmd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/assets/cmmd.png -------------------------------------------------------------------------------- /assets/human_eval.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/assets/human_eval.gif -------------------------------------------------------------------------------- /assets/jumpy_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/assets/jumpy_loss.png -------------------------------------------------------------------------------- /compare-two-images.py: -------------------------------------------------------------------------------- 1 | """Compares the quality of two images.""" 2 | import argparse 3 | import torch 4 | 5 | from PIL import Image 6 | from tabulate import tabulate 7 | 8 | from image_eval.evaluators import EvaluatorType 9 | from image_eval.evaluators import get_evaluators_for_type 10 | 11 | 12 | def read_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--image1", "-i1", help="Path to the first image", type=str, required=True) 15 | parser.add_argument("--image2", "-i2", help="Path to the second image", type=str, required=True) 16 | parser.add_argument("--prompt", "-p", help="Prompt used to generated the images", type=str) 17 | return parser.parse_args() 18 | 19 | def main(): 20 | args = read_args() 21 | image1 = Image.open(args.image1).convert("RGB") 22 | image2 = Image.open(args.image2).convert("RGB") 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | 25 | individual_metrics = [] 26 | for evaluator_cls in get_evaluators_for_type(EvaluatorType.IMAGE_QUALITY): 27 | evaluator = evaluator_cls(device) 28 | if not evaluator.is_useful([image1]): 29 | continue 30 | results1 = evaluator.evaluate(generated_images=[image1]) 31 | results2 = evaluator.evaluate(generated_images=[image2]) 32 | for metric, value1 in results1.items(): 33 | value2 = results2[metric] 34 | individual_metrics.append((metric, value1, value2)) 35 | 36 | if args.prompt: 37 | for evaluator_cls in get_evaluators_for_type(EvaluatorType.CONTROLLABILITY): 38 | evaluator = evaluator_cls(device) 39 | if not evaluator.is_useful([image1]): 40 | continue 41 | results1 = evaluator.evaluate(generated_images=[image1], prompts=[args.prompt]) 42 | results2 = evaluator.evaluate(generated_images=[image2], prompts=[args.prompt]) 43 | for metric, value1 in results1.items(): 44 | value2 = results2[metric] 45 | individual_metrics.append((metric, value1, value2)) 46 | 47 | print(tabulate(individual_metrics, 48 | headers=["Metric Name", "Image 1", "Image 2"], 49 | tablefmt="orgtbl")) 50 | 51 | fidelity_metrics = [] 52 | for evaluator_cls in get_evaluators_for_type(EvaluatorType.FIDELITY): 53 | evaluator = evaluator_cls(device) 54 | if not evaluator.is_useful([image1]): 55 | continue 56 | results = 
evaluator.evaluate(generated_images=[image1], real_images=[image2]) 57 | fidelity_metrics.extend([(metric, value) for metric, value in results.items()]) 58 | 59 | print(tabulate(fidelity_metrics, 60 | headers=["Metric Name", "Fidelity Score"], 61 | tablefmt="orgtbl")) 62 | 63 | 64 | if __name__ == "__main__": 65 | # If we don't explicitly mark all models for inference, Huggingface seems to hold on to some 66 | # object references even after they're not needed anymore (perhaps to keep gradients around), 67 | # which causes this script to OOM when multiple evaluators are run in a sequence. 68 | # See https://github.com/huggingface/transformers/issues/26275. 69 | with torch.inference_mode(): 70 | main() 71 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import sys 6 | 7 | import torch 8 | from PIL import Image 9 | from tabulate import tabulate 10 | 11 | from image_eval.evaluators import ( 12 | AestheticPredictorEvaluator, 13 | CLIPScoreEvaluator, 14 | CMMDEvaluator, 15 | EvaluatorType, 16 | FIDEvaluator, 17 | HumanPreferenceScoreEvaluator, 18 | ImageRewardEvaluator, 19 | InceptionScoreEvaluator, 20 | CentroidSimilarityEvaluator, 21 | VendiScoreEvaluator 22 | ) 23 | from image_eval.pairwise_evaluators import ( 24 | LPIPSEvaluator, 25 | MultiSSIMEvaluator, 26 | PSNREvaluator, 27 | UIQIEvaluator 28 | ) 29 | 30 | from streamlit.web import cli as stcli 31 | from typing import Dict 32 | 33 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 34 | 35 | METRIC_NAME_TO_EVALUATOR = { 36 | "clip_score": { 37 | "evaluator": CLIPScoreEvaluator, 38 | "description": "This metric corresponds to the cosine similarity between the visual CLIP embedding for " 39 | "an image and the textual CLIP embedding for a caption. The score is bound between 0 and " 40 | "100 with 100 being the best score. For more info, check out https://arxiv.org/abs/2104.08718" 41 | }, 42 | "centroid_similarity": { 43 | "evaluator": CentroidSimilarityEvaluator, 44 | "description": "This metric reflects the average cosine similarity between the cluster center of reference images " 45 | "and the generated images. The score is bound between 0 and 1, with 1 being the best score. " 46 | "The purpose of the metric is to measure how in-style generated images are, compared to the real ones." 47 | }, 48 | "inception_score": { 49 | "evaluator": InceptionScoreEvaluator, 50 | "description": "This metrics uses the Inception V3 model to compute class probabilities for generated images. " 51 | "and then calculates the KL divergence between the marginal distribution of the class " 52 | "probabilities and the conditional distribution of the class probabilities given the generated " 53 | "images. The score is bound between 1 and the number of classes supported by the classification " 54 | "model. The score is computed on random splits of the data so both a mean and standard deviation " 55 | "are reported. For more info, check out https://arxiv.org/abs/1606.03498" 56 | }, 57 | "fid": { 58 | "evaluator": FIDEvaluator, 59 | "description": "This metrics uses the Inception V3 model to compute a multivariate Gaussian for a set of real images" 60 | "as well as a multivariate gaussian for a set of fake images. A distance is then computed using " 61 | "the summary statistics of these Gaussians. 
A lower score is better with a 0 being a perfect " 62 | "score indicating identical groups of images. This metric computes a distance for features" 63 | "derived from the 64, 192, 768, and 2048 feature layers. For more info, check out https://arxiv.org/abs/1512.00567" 64 | }, 65 | "cmmd": { 66 | "evaluator": CMMDEvaluator, 67 | "description": "A better FID alternative. See https://arxiv.org/abs/2401.09603.", 68 | }, 69 | "aesthetic_predictor": { 70 | "evaluator": AestheticPredictorEvaluator, 71 | "description": "This metrics trains a model to predict an aesthetic score using a multilayer perceptron" 72 | "trained from the AVA dataset (http://refbase.cvc.uab.es/files/MMP2012a.pdf) using CLIP input embeddings." 73 | "A larger score indicates a better model." 74 | }, 75 | "image_reward": { 76 | "evaluator": ImageRewardEvaluator, 77 | "description": "This metrics trains a model to predict image rewards using a dataset of human preferences for images." 78 | "Each reward is intended to output a value sampled from a Gaussian with 0 mean and stddev 1. For more details" 79 | "check out https://arxiv.org/pdf/2304.05977.pdf" 80 | }, 81 | "human_preference_score": { 82 | "evaluator": HumanPreferenceScoreEvaluator, 83 | "description": "This metric outputs an estimate of the human preference for an image based on the paper https://tgxs002.github.io/align_sd_web/" 84 | "The metric is bound between -100 and 100 with 100 being the best score." 85 | }, 86 | "vendi_score": { 87 | "evaluator": VendiScoreEvaluator, 88 | "description": "This metric evaluates how diverse the generated image set is. We suggest generating all images with the same prompt." 89 | "See https://arxiv.org/abs/2210.02410.", 90 | }, 91 | "lpips_score": { 92 | "evaluator": LPIPSEvaluator, 93 | "description": "Calculates Learned Perceptual Image Patch Similarity (LPIPS) score as the distance between the " 94 | "activations of two image patches for some deep network (VGG). The score is between 0 and 1, where 0 means the " 95 | "images are identical. See https://arxiv.org/pdf/1801.03924.", 96 | }, 97 | "multi_ssime_score": { 98 | "evaluator": MultiSSIMEvaluator, 99 | "description": "Calculates Multi-scale Structural Similarity Index Measure (SSIM). This is an extension of " 100 | "SSIM, which assesses the similarity between two images based on three components: luminance, contrast, and " 101 | "structure. The score is between -1 and 1, where 1 = perfect similarity, 0 = no similarity and " 102 | "-1 = perfect anti-corelation. See https://ieeexplore.ieee.org/document/1292216.", 103 | }, 104 | "psnr_score": { 105 | "evaluator": PSNREvaluator, 106 | "description": "Calculates Peak Signal-to-Noise Ratio (PSNR). It was originally designed to measure the " 107 | "quality of reconstructed or compressed images compared to their original versions. Its values are between " 108 | "-infinity and +infinity, where identical images score +infinity. See " 109 | "https://ieeexplore.ieee.org/document/1163711.", 110 | }, 111 | "uiqi_score": { 112 | "evaluator": UIQIEvaluator, 113 | "description": "Calculates Universal Image Quality Index (UIQI). Based on the idea of comparing statistical " 114 | "properties of an original and a distorted image in both the spatial and frequency domains. The calculation " 115 | "involves several steps, including the computation of mean, variance, and covariance of the pixel values in " 116 | "local windows of the images. It also considers factors like luminance, contrast, and structure. 
See " 117 | "https://ieeexplore.ieee.org/document/1284395.", 118 | }, 119 | } 120 | 121 | 122 | def read_args(): 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--metrics", "-m", 125 | default="all", 126 | help="valid values are: (1) all, (2) one of the following categories: image_quality, " 127 | "controllability, fidelity, pairwise_similarity, diversity, or (3) a comma-separated list " 128 | "of metric names, for example: clip_score,style_similarity.", 129 | type=str) 130 | parser.add_argument("--generated-images", "-g", 131 | help="path to directory containing generated images to evaluate; for pairwise_similarity, " 132 | "-g and -r need to have the same number of images with the same filenames.", 133 | type=str, 134 | required=True) 135 | parser.add_argument("--real-images", "-r", 136 | help="path to directory containing real images to use for evaluation; for pairwise_similarity, " 137 | "-g and -r need to have the same number of images with the same filenames.", 138 | type=str) 139 | parser.add_argument("--prompts", "-p", help="path to file containing mapping from image to associated prompt", 140 | type=str) 141 | parser.add_argument("--available-metrics", 142 | help="description of all the available metrics and a synopsis of their properties", 143 | action="store_true") 144 | parser.add_argument("--local-human-eval", 145 | help="run a local instance of streamlit for human evaluation", 146 | action="store_true") 147 | parser.add_argument("--model-predictions-json", 148 | help="path to json file containing model predictions") 149 | parser.add_argument("--aesthetic-predictor-model-url", 150 | default=AestheticPredictorEvaluator.DEFAULT_URL, 151 | help="Model checkpoint for the aesthetic predictor evaluator.") 152 | parser.add_argument("--human-preference-score-model-url", 153 | default=HumanPreferenceScoreEvaluator.DEFAULT_URL, 154 | help="Model checkpoint for the human preference score evaluator.") 155 | return parser.parse_args() 156 | 157 | 158 | def read_images(image_dir: str) -> Dict[str, Image.Image]: 159 | """Reads all the images in a given folder.""" 160 | images = {} 161 | # It's important to sort the filenames for pairwise similarity metrics. 162 | image_filenames = sorted(os.listdir(image_dir)) 163 | for image_filename in image_filenames: 164 | image_path = os.path.join(image_dir, image_filename) 165 | try: 166 | pil_image = Image.open(image_path).convert("RGB") 167 | except Exception: 168 | # Ignore non-image files in this folder. 169 | logging.warning(f"Cannot read image from {image_path}. Skipping.") 170 | continue 171 | images[image_filename] = pil_image 172 | return images 173 | 174 | 175 | def read_prompts_for_images(images_by_filename: Dict[str, Image.Image], prompts_path: str): 176 | images = [] 177 | prompts = [] 178 | 179 | with open(prompts_path, "r") as f: 180 | prompts_by_image_filename = json.load(f) 181 | # It's important to sort the filenames for pairwise similarity metrics. 182 | prompts_by_image_filename = sorted(prompts_by_image_filename.items(), key=lambda x: x[0]) 183 | 184 | for image_filename, prompt in prompts_by_image_filename: 185 | image = images_by_filename.get(image_filename) 186 | if image: 187 | images.append(image) 188 | prompts.append(prompt) 189 | else: 190 | logging.warning(f"Could not find image {image_filename}. 
" 191 | f"Available images are: {images_by_filename.keys()}") 192 | return images, prompts 193 | 194 | 195 | def main(): 196 | args = read_args() 197 | if args.available_metrics: 198 | metric_descr = [] 199 | for metric, info in METRIC_NAME_TO_EVALUATOR.items(): 200 | metric_descr.append([metric, info["description"]]) 201 | print(tabulate(metric_descr, headers=["Metric Name", "Description"], tablefmt="grid")) 202 | return 203 | 204 | if args.local_human_eval: 205 | assert args.model_predictions_json is not None, "Must provide model predictions json" 206 | lib_folder = os.path.dirname(os.path.realpath(__file__)) 207 | sys.argv = ["streamlit", "run", f"{lib_folder}/image_eval/local_ab_test.py", "--", "--model-predictions-json", args.model_predictions_json] 208 | sys.exit(stcli.main()) 209 | 210 | metrics_explicitly_specified = [] 211 | if args.metrics == "all": 212 | # We exclude PAIRWISE_SIMILARITY metrics from here and only calculate them when explicilty 213 | # specified by the user, as they require aligned datasets. 214 | metrics = [name for name, metric in METRIC_NAME_TO_EVALUATOR.items() 215 | if metric["evaluator"].TYPE != EvaluatorType.PAIRWISE_SIMILARITY] 216 | args.metrics = ",".join(metrics) 217 | elif args.metrics in [member.name.lower() for member in EvaluatorType]: 218 | args.metrics = ",".join([ 219 | name for name, metric in METRIC_NAME_TO_EVALUATOR.items() 220 | if metric["evaluator"].TYPE.name.lower() == args.metrics 221 | ]) 222 | else: 223 | # Metrics must be comma-separated 224 | metrics_explicitly_specified = args.metrics.split(",") 225 | metrics = args.metrics.split(",") 226 | 227 | generated_images_by_filename = read_images(args.generated_images) 228 | real_images = list(read_images(args.real_images).values()) if args.real_images else None 229 | 230 | if args.prompts: 231 | generated_images, prompts = read_prompts_for_images(generated_images_by_filename, args.prompts) 232 | else: 233 | generated_images = list(generated_images_by_filename.values()) 234 | prompts = None 235 | 236 | computed_metrics = [] 237 | device = "cuda" if torch.cuda.is_available() else "cpu" 238 | logging.info(f"Running evaluation on device: {device}") 239 | 240 | # Compute all metrics 241 | all_computed_metrics = {} 242 | for metric in metrics: 243 | if metric == "aesthetic_predictor": 244 | evaluator = AestheticPredictorEvaluator(device, args.aesthetic_predictor_model_url) 245 | elif metric == "human_preference_score": 246 | evaluator = HumanPreferenceScoreEvaluator(device, args.human_preference_score_model_url) 247 | else: 248 | metric_evaluator = METRIC_NAME_TO_EVALUATOR[metric]["evaluator"] 249 | evaluator = metric_evaluator(device) 250 | 251 | if (not evaluator.should_trigger_for_data(generated_images) and 252 | not metric in metrics_explicitly_specified): 253 | logging.warning(f"Skipping metric {metric} as it is not useful for the given images.") 254 | continue 255 | 256 | logging.info(f"Computing metric {metric}...") 257 | computed_metrics = evaluator.evaluate(generated_images, 258 | real_images=real_images, 259 | prompts=prompts) 260 | all_computed_metrics.update(computed_metrics) 261 | 262 | # Print all results 263 | print(tabulate(all_computed_metrics.items(), 264 | headers=["Metric Name", "Value"], 265 | tablefmt="orgtbl")) 266 | 267 | 268 | if __name__ == "__main__": 269 | # If we don't explicitly mark all models for inference, Huggingface seems to hold on to some 270 | # object references even after they're not needed anymore (perhaps to keep gradients around), 271 | # which causes 
this script to OOM when multiple evaluators are run in a sequence. 272 | # See https://github.com/huggingface/transformers/issues/26275. 273 | with torch.inference_mode(): 274 | main() 275 | -------------------------------------------------------------------------------- /image_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/image_eval/__init__.py -------------------------------------------------------------------------------- /image_eval/encoders.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from contextlib import redirect_stdout 3 | import logging 4 | import numpy as np 5 | import os 6 | import torch 7 | from io import StringIO 8 | 9 | from PIL import Image 10 | from insightface.app import FaceAnalysis 11 | from transformers import AutoImageProcessor 12 | from transformers import Dinov2Model 13 | from transformers import CLIPModel 14 | from transformers import CLIPProcessor 15 | from transformers import ConvNextV2Model 16 | 17 | # When HuggingFace is down, use the cache and don't make any calls to them. 18 | LOCAL_FILES_ONLY = os.getenv("LOCAL_FILES_ONLY", False) 19 | 20 | 21 | class BaseEncoder(abc.ABC): 22 | """A model that maps images from pixel space to an embedding.""" 23 | def __init__(self, id: str, device: str): 24 | self.id = id 25 | self.device = device 26 | 27 | @abc.abstractmethod 28 | def encode(self, images: list[Image.Image]): 29 | pass 30 | 31 | 32 | class CLIPEncoder(BaseEncoder): 33 | """ 34 | Original paper: https://arxiv.org/abs/2103.00020 (published Feb 2021). 35 | Used by LyCORIS for evaluation: https://arxiv.org/pdf/2309.14859.pdf. 36 | """ 37 | def __init__(self, device: str): 38 | super().__init__("clip", device) 39 | model_name = "openai/clip-vit-base-patch16" 40 | self.model = CLIPModel.from_pretrained(model_name, local_files_only=LOCAL_FILES_ONLY)\ 41 | .to(self.device) 42 | self.processor = CLIPProcessor.from_pretrained(model_name) 43 | 44 | def encode(self, images: list[Image.Image]): 45 | image_inputs = self.processor(text=None, images=images, return_tensors="pt", padding=True) 46 | image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()} 47 | return self.model.get_image_features(**image_inputs) 48 | 49 | 50 | class DinoV2Encoder(BaseEncoder): 51 | """ 52 | Original paper: https://arxiv.org/abs/2304.07193 (published April 2023). 53 | Used by LyCORIS for evaluation: https://arxiv.org/pdf/2309.14859.pdf. 54 | 55 | Compared to CLIP, which used text-guided pretraining (aligning images against captions), DinoV2 56 | used self-supervised learning on images alone. Its training objective maximizes agreement 57 | between different patches within the same image. The intuition behind not relying on captions is 58 | that it enables the model to pay attention to finer details, not just the ones captured in text. 59 | DinoV2 was trained on a dataset of 142M automatically curated images. 
60 | """ 61 | def __init__(self, device: str): 62 | super().__init__("dino_v2", device) 63 | model_name = "facebook/dinov2-base" 64 | self.model = Dinov2Model.from_pretrained(model_name, local_files_only=LOCAL_FILES_ONLY)\ 65 | .to(self.device) 66 | self.processor = AutoImageProcessor.from_pretrained(model_name) 67 | 68 | def encode(self, images: list[Image.Image]): 69 | image_inputs = self.processor(text=None, images=images, return_tensors="pt", padding=True) 70 | image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()} 71 | return self.model(**image_inputs).pooler_output 72 | 73 | 74 | class ConvNeXtV2Encoder(BaseEncoder): 75 | """ 76 | Original paper: https://arxiv.org/abs/2301.00808 (published Jan 2023). 77 | Used by LyCORIS for evaluation: https://arxiv.org/pdf/2309.14859.pdf. 78 | 79 | Similarly to DinoV2, ConvNeXtV2 did not use text-guided pretraining. It was trained on an image 80 | dataset to recover masked patches. Compared to DinoV2 (which uses a ViT = Visual Transformer), 81 | ConvNeXtV2 used a convolutional architecture. ConvNeXtV2 is the successor of MAE (Masked Auto 82 | Encoder) embeddings, also coming out of Meta (https://arxiv.org/abs/2111.06377). 83 | """ 84 | def __init__(self, device: str): 85 | super().__init__("convnext_v2", device) 86 | model_name = "facebook/convnextv2-base-22k-384" 87 | self.model = ConvNextV2Model.from_pretrained(model_name, local_files_only=LOCAL_FILES_ONLY)\ 88 | .to(self.device) 89 | self.processor = AutoImageProcessor.from_pretrained(model_name) 90 | 91 | def encode(self, images: list[Image.Image]): 92 | image_inputs = self.processor(text=None, images=images, return_tensors="pt", padding=True) 93 | image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()} 94 | return self.model(**image_inputs).pooler_output 95 | 96 | 97 | class InsightFaceEncoder(BaseEncoder): 98 | """Computes face embeddings; see https://insightface.ai/.""" 99 | EMBEDDING_SIZE = 512 100 | 101 | def __init__(self, device: str): 102 | super().__init__("insightface", device) 103 | provider = "CUDAExecutionProvider" if "cuda" in device else "CPUExecutionProvider" 104 | # The `insightface` library is very verbose, so we need to silence it. 105 | with redirect_stdout(StringIO()): 106 | self.app = FaceAnalysis(providers=[provider]) 107 | 108 | def encode(self, images: list[Image.Image]): 109 | with redirect_stdout(StringIO()): 110 | self.app.prepare(ctx_id=0, det_size=images[0].size) 111 | all_embeddings = [] 112 | for image in images: 113 | try: 114 | # Returns one result for each person identified in the image. 
115 | results = self.app.get(np.array(image)) 116 | embeddings = [r["embedding"] for r in results] 117 | except Exception as e: 118 | logging.warning(f"The `insightface` library failed to extract embeddings: {e}") 119 | embeddings = [] 120 | 121 | if not embeddings: 122 | embeddings.append(np.zeros(self.EMBEDDING_SIZE)) 123 | all_embeddings.append(np.mean(np.array(embeddings), axis=0)) 124 | 125 | all_embeddings = np.stack(all_embeddings, axis=0) 126 | return torch.tensor(all_embeddings, dtype=torch.float32).to(self.device) 127 | 128 | 129 | ALL_ENCODER_CLASSES = [CLIPEncoder, DinoV2Encoder, ConvNeXtV2Encoder, InsightFaceEncoder] 130 | -------------------------------------------------------------------------------- /image_eval/evaluators.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import ImageReward as RM 4 | import clip 5 | import torch 6 | from PIL import Image 7 | from enum import Enum 8 | from torchmetrics.image.fid import FrechetInceptionDistance 9 | from torchmetrics.image.inception import InceptionScore 10 | from torchmetrics.multimodal.clip_score import CLIPScore 11 | from torchvision import transforms 12 | from typing import Dict 13 | from vendi_score import vendi 14 | 15 | from image_eval.improved_aesthetic_predictor import run_inference 16 | from image_eval.encoders import ALL_ENCODER_CLASSES 17 | from image_eval.model_utils import download_model 18 | 19 | torch.manual_seed(42) 20 | 21 | 22 | # TODO (mihail): Decouple this so not all evaluators are in the same file 23 | 24 | 25 | class EvaluatorType(Enum): 26 | """Follows the evaluation terminology established by https://arxiv.org/pdf/2309.14859.pdf.""" 27 | # The visual appeal of the generated images (naturalness, absence of artifacts or deformations). 28 | IMAGE_QUALITY = 1 29 | 30 | # The model’s ability to generate images that align well with text prompts. 31 | CONTROLLABILITY = 2 32 | 33 | # The extent to which generated images adhere to the target concept. Relevant for fine-tuned 34 | # models rather than foundational ones (e.g. vanilla Stable Diffusion). 35 | FIDELITY = 3 36 | 37 | # The pairwise similarity between two sets of images. While FIDELITY allows the two sets to have 38 | # different sizes and makes bulk comparisons, PAIRWISE_SIMILARITY requires a 1:1 correspondence 39 | # between images and makes pairwise comparisons. 40 | PAIRWISE_SIMILARITY = 4 41 | 42 | # The variety of images that are produced from a single or a set of prompts. 
43 | DIVERSITY = 5 44 | 45 | 46 | class BaseEvaluator(abc.ABC): 47 | """Base class for evaluators.""" 48 | def __init__(self, device: str): 49 | self.device = device 50 | 51 | @abc.abstractmethod 52 | def evaluate(self, generated_images: list[Image.Image], *args, **kwargs) -> Dict[str, float]: 53 | pass 54 | 55 | def should_trigger_for_data(self, generated_images: list[Image.Image], *args, **kwargs) -> bool: 56 | return True 57 | 58 | 59 | class CLIPScoreEvaluator(BaseEvaluator): 60 | TYPE = EvaluatorType.CONTROLLABILITY 61 | HIGHER_IS_BETTER = True 62 | 63 | def __init__(self, device: str): 64 | super().__init__(device) 65 | self.evaluator = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16").to(self.device) 66 | 67 | def evaluate(self, 68 | generated_images: list[Image.Image], 69 | prompts: list[str], 70 | **ignored_kwargs) -> Dict[str, float]: 71 | torch_imgs = [transforms.ToTensor()(img).to(self.device) for img in generated_images] 72 | self.evaluator.update(torch_imgs, prompts) 73 | return {"clip_score": self.evaluator.compute().item()} 74 | 75 | 76 | class CentroidSimilarityEvaluator(BaseEvaluator): 77 | TYPE = EvaluatorType.FIDELITY 78 | HIGHER_IS_BETTER = True 79 | 80 | def __init__(self, device: str): 81 | super().__init__(device) 82 | 83 | def evaluate(self, 84 | generated_images: list[Image.Image], 85 | real_images: list[Image.Image], 86 | **ignored_kwargs) -> Dict[str, float]: 87 | """Returns the average cosine similarity between the generated images and the center of the cluster defined by real images.""" 88 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 89 | results = {} 90 | 91 | for encoder_cls in ALL_ENCODER_CLASSES: 92 | encoder = encoder_cls(self.device) 93 | generated_embeddings = encoder.encode(generated_images) 94 | generated_center = torch.mean(generated_embeddings, axis=0, keepdim=True) 95 | real_embeddings = encoder.encode(real_images) 96 | real_center = torch.mean(real_embeddings, axis=0, keepdim=True) 97 | results[f"centroid_similarity_{encoder.id}"] = cos(generated_center, real_center).item() 98 | return results 99 | 100 | 101 | class InceptionScoreEvaluator(BaseEvaluator): 102 | TYPE = EvaluatorType.IMAGE_QUALITY 103 | HIGHER_IS_BETTER = True 104 | 105 | def __init__(self, device: str): 106 | super().__init__(device) 107 | self.evaluator = InceptionScore().to(self.device) 108 | 109 | def evaluate(self, generated_images: list[Image.Image], **ignored_kwargs) -> Dict[str, float]: 110 | torch_imgs = torch.stack([transforms.ToTensor()(img).to(torch.uint8).to(self.device) 111 | for img in generated_images]) 112 | self.evaluator.update(torch_imgs) 113 | mean, stddev = self.evaluator.compute() 114 | return {"inception_score_mean": mean.item(), 115 | "inception_score_stddev": stddev.item()} 116 | 117 | def should_trigger_for_data(self, generated_images: list[Image.Image], **ignored_kwargs) -> bool: 118 | # InceptionScore calculates a marginal distribution over the objects identified in the 119 | # images. This requires a large number of images to be useful; papers use 30k-60k images. 
120 | # See this blog post for a high-level description of how this works: 121 | # https://medium.com/octavian-ai/a-simple-explanation-of-the-inception-score-372dff6a8c7a 122 | return len(generated_images) >= 1000 123 | 124 | 125 | class FIDEvaluator(BaseEvaluator): 126 | TYPE = EvaluatorType.IMAGE_QUALITY 127 | HIGHER_IS_BETTER = False 128 | 129 | def __init__(self, device: str): 130 | super().__init__(device) 131 | self.evaluator64 = FrechetInceptionDistance(feature=64).to(self.device).set_dtype(torch.float64) 132 | self.evaluator192 = FrechetInceptionDistance(feature=192).to(self.device).set_dtype(torch.float64) 133 | self.evaluator768 = FrechetInceptionDistance(feature=768).to(self.device).set_dtype(torch.float64) 134 | self.evaluator2048 = FrechetInceptionDistance(feature=2048).to(self.device).set_dtype(torch.float64) 135 | 136 | def evaluate(self, 137 | generated_images: list[Image.Image], 138 | real_images: list[Image.Image], 139 | **ignored_kwargs) -> Dict[str, float]: 140 | torch_gen_imgs = torch.stack([transforms.ToTensor()(img).to(torch.uint8).to(self.device) 141 | for img in generated_images]) 142 | 143 | # Real images (since they were not generated) might have various sizes. We'll resize them to the generated size. 144 | gen_size = generated_images[0].size 145 | real_images = [img.resize(gen_size) for img in real_images] 146 | torch_real_imgs = torch.stack([transforms.ToTensor()(img).to(torch.uint8).to(self.device) 147 | for img in real_images]) 148 | 149 | self.evaluator64.update(torch_gen_imgs, real=False) 150 | self.evaluator64.update(torch_real_imgs, real=True) 151 | self.evaluator192.update(torch_gen_imgs, real=False) 152 | self.evaluator192.update(torch_real_imgs, real=True) 153 | self.evaluator768.update(torch_gen_imgs, real=False) 154 | self.evaluator768.update(torch_real_imgs, real=True) 155 | self.evaluator2048.update(torch_gen_imgs, real=False) 156 | self.evaluator2048.update(torch_real_imgs, real=True) 157 | return {"fid_score_64": self.evaluator64.compute(), 158 | "fid_score_192": self.evaluator192.compute(), 159 | "fid_score_768": self.evaluator768.compute(), 160 | "fid_score_2048": self.evaluator2048.compute()} 161 | 162 | def should_trigger_for_data(self, generated_images: list[Image.Image]): 163 | # Similarly to Inception Score, FID calculates marginal distributions over datasets, and is 164 | # only useful when there are thousands of images. 165 | return len(generated_images) >= 1000 166 | 167 | 168 | class CMMDEvaluator(BaseEvaluator): 169 | """Original paper: https://arxiv.org/abs/2401.09603 (published Jan 2024). 170 | 171 | This implementation is adapted from https://github.com/sayakpaul/cmmd-pytorch/blob/main/distance.py. 
172 | """ 173 | TYPE = EvaluatorType.FIDELITY 174 | HIGHER_IS_BETTER = False 175 | _SIGMA = 10 176 | 177 | def __init__(self, device: str): 178 | super().__init__(device) 179 | 180 | def evaluate(self, 181 | generated_images: list[Image.Image], 182 | real_images: list[Image.Image], 183 | **ignored_kwargs) -> Dict[str, float]: 184 | results = {} 185 | for encoder_cls in ALL_ENCODER_CLASSES: 186 | encoder = encoder_cls(self.device) 187 | x = encoder.encode(generated_images) 188 | y = encoder.encode(real_images) 189 | 190 | x_sqnorms = torch.diag(torch.matmul(x, x.T)) 191 | y_sqnorms = torch.diag(torch.matmul(y, y.T)) 192 | 193 | gamma = 1 / (2 * CMMDEvaluator._SIGMA**2) 194 | k_xx = torch.mean( 195 | torch.exp(-gamma * (-2 * torch.matmul(x, x.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(x_sqnorms, 0))) 196 | ) 197 | k_xy = torch.mean( 198 | torch.exp(-gamma * (-2 * torch.matmul(x, y.T) + torch.unsqueeze(x_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0))) 199 | ) 200 | k_yy = torch.mean( 201 | torch.exp(-gamma * (-2 * torch.matmul(y, y.T) + torch.unsqueeze(y_sqnorms, 1) + torch.unsqueeze(y_sqnorms, 0))) 202 | ) 203 | distance = k_xx + k_yy - 2 * k_xy 204 | results[f"cmmd_{encoder.id}"] = distance.item() 205 | return results 206 | 207 | 208 | class AestheticPredictorEvaluator(BaseEvaluator): 209 | TYPE = EvaluatorType.IMAGE_QUALITY 210 | HIGHER_IS_BETTER = True 211 | DEFAULT_URL = "https://github.com/christophschuhmann/improved-aesthetic-predictor/raw/main/sac+logos+ava1-l14-linearMSE.pth" 212 | 213 | def __init__(self, device: str, model_url: str = DEFAULT_URL): 214 | super().__init__(device) 215 | self.model_path = download_model(model_url, "aesthetic_predictor.pth") 216 | 217 | def evaluate(self, generated_images: list[Image.Image], **ignored_kwargs) -> Dict[str, float]: 218 | return {"aesthetic_predictor": 219 | run_inference(generated_images, self.model_path, self.device)} 220 | 221 | 222 | class ImageRewardEvaluator(BaseEvaluator): 223 | TYPE = EvaluatorType.CONTROLLABILITY 224 | HIGHER_IS_BETTER = True 225 | 226 | def __init__(self, device: str): 227 | super().__init__(device) 228 | self.evaluator = RM.load("ImageReward-v1.0") 229 | 230 | def evaluate(self, 231 | generated_images: list[Image.Image], 232 | prompts: list[str], 233 | **ignored_kwargs) -> Dict[str, float]: 234 | # Returns the average image reward 235 | rewards = [] 236 | for image, prompt in zip(generated_images, prompts): 237 | rewards.append(self.evaluator.score(prompt, image)) 238 | return {"image_reward": sum(rewards) / len(rewards)} 239 | 240 | 241 | class HumanPreferenceScoreEvaluator(BaseEvaluator): 242 | TYPE = EvaluatorType.CONTROLLABILITY 243 | HIGHER_IS_BETTER = True 244 | DEFAULT_URL = "https://mycuhk-my.sharepoint.com/:u:/g/personal/1155172150_link_cuhk_edu_hk/EWDmzdoqa1tEgFIGgR5E7gYBTaQktJcxoOYRoTHWzwzNcw?download=1" 245 | 246 | def __init__(self, device: str, model_url: str = DEFAULT_URL): 247 | super().__init__(device) 248 | self.hps_model_path = download_model(model_url, "human_preference_score.pt") 249 | self.model, self.preprocess = clip.load("ViT-L/14", device=self.device) 250 | 251 | if torch.cuda.is_available(): 252 | params = torch.load(self.hps_model_path)['state_dict'] 253 | else: 254 | params = torch.load(self.hps_model_path, map_location=self.device)['state_dict'] 255 | self.model.load_state_dict(params) 256 | 257 | def evaluate(self, 258 | generated_images: list[Image.Image], 259 | prompts: list[str], 260 | **ignored_kwargs) -> Dict[str, float]: 261 | images = 
[self.preprocess(img).to(self.device) for img in generated_images] 262 | images = torch.stack(images) 263 | texts = clip.tokenize(prompts, truncate=True).to(self.device) 264 | 265 | image_features = self.model.encode_image(images) 266 | text_features = self.model.encode_text(texts) 267 | 268 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 269 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 270 | 271 | hps = image_features @ text_features.T 272 | hps = hps.diagonal() 273 | return {"human_preference_score": torch.mean(hps).detach().item()} 274 | 275 | 276 | class VendiScoreEvaluator(BaseEvaluator): 277 | TYPE = EvaluatorType.DIVERSITY 278 | HIGHER_IS_BETTER = True 279 | 280 | def __init__(self, device: str): 281 | super().__init__(device) 282 | 283 | def evaluate(self, generated_images: list[Image.Image], **ignored_kwargs) -> Dict[str, float]: 284 | results = {} 285 | for encoder_cls in ALL_ENCODER_CLASSES: 286 | encoder = encoder_cls(self.device) 287 | embeddings = encoder.encode(generated_images).cpu().detach().numpy() 288 | results[f"vendi_score_{encoder.id}"] = vendi.score_X(embeddings).item() 289 | return results 290 | 291 | 292 | def get_evaluators_for_type(evaluator_type: EvaluatorType): 293 | return [evaluator for evaluator in globals().values() 294 | if isinstance(evaluator, type) 295 | and hasattr(evaluator, "TYPE") 296 | and evaluator.TYPE == evaluator_type] 297 | -------------------------------------------------------------------------------- /image_eval/improved_aesthetic_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import clip 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | 10 | """ 11 | Adapted from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main 12 | 13 | This script will predict the aesthetic score for provided image files. 
14 | """ 15 | 16 | 17 | # If you changed the MLP architecture during training, change it also here: 18 | class MLP(pl.LightningModule): 19 | def __init__(self, input_size, xcol='emb', ycol='avg_rating'): 20 | super().__init__() 21 | self.input_size = input_size 22 | self.xcol = xcol 23 | self.ycol = ycol 24 | self.layers = nn.Sequential( 25 | nn.Linear(self.input_size, 1024), 26 | nn.Dropout(0.2), 27 | nn.Linear(1024, 128), 28 | nn.Dropout(0.2), 29 | nn.Linear(128, 64), 30 | nn.Dropout(0.1), 31 | nn.Linear(64, 16), 32 | nn.Linear(16, 1) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.layers(x) 37 | 38 | def training_step(self, batch, batch_idx): 39 | x = batch[self.xcol] 40 | y = batch[self.ycol].reshape(-1, 1) 41 | x_hat = self.layers(x) 42 | loss = F.mse_loss(x_hat, y) 43 | return loss 44 | 45 | def validation_step(self, batch, batch_idx): 46 | x = batch[self.xcol] 47 | y = batch[self.ycol].reshape(-1, 1) 48 | x_hat = self.layers(x) 49 | loss = F.mse_loss(x_hat, y) 50 | return loss 51 | 52 | def configure_optimizers(self): 53 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 54 | return optimizer 55 | 56 | 57 | def normalized(a, axis=-1, order=2): 58 | import numpy as np # pylint: disable=import-outside-toplevel 59 | 60 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 61 | l2[l2 == 0] = 1 62 | return a / np.expand_dims(l2, axis) 63 | 64 | 65 | def run_inference(images: list[Image.Image], model_path: str, device: str): 66 | model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14 67 | if torch.cuda.is_available(): 68 | s = torch.load(model_path) 69 | else: 70 | s = torch.load(model_path, map_location=torch.device("cpu")) 71 | 72 | model.load_state_dict(s) 73 | 74 | model.to(device) 75 | model.eval() 76 | 77 | model2, preprocess = clip.load("ViT-L/14", device=device) 78 | avg_aesthetic_score = 0 79 | for pil_image in images: 80 | image = preprocess(pil_image).unsqueeze(0).to(device) 81 | with torch.no_grad(): 82 | image_features = model2.encode_image(image) 83 | 84 | im_emb_arr = normalized(image_features.cpu().detach().numpy()) 85 | 86 | if torch.cuda.is_available(): 87 | prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor)) 88 | else: 89 | prediction = model(torch.from_numpy(im_emb_arr).to(device)) 90 | 91 | avg_aesthetic_score += prediction.item() 92 | 93 | return avg_aesthetic_score / len(images) 94 | -------------------------------------------------------------------------------- /image_eval/local_ab_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | from collections import Counter 5 | from collections import defaultdict 6 | 7 | import streamlit as st 8 | from PIL import Image 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--model-predictions-json", help="path to json file containing model predictions") 12 | args = parser.parse_args() 13 | 14 | 15 | def get_model_predictions_from_file(json_file: str) -> dict: 16 | model_preds = defaultdict(list) 17 | with open(json_file) as f: 18 | data = json.load(f) 19 | for entry in data: 20 | model_preds["model_1"].append(entry["model_1"]) 21 | model_preds["model_2"].append(entry["model_2"]) 22 | 23 | assert len(model_preds["model_1"]) == len(model_preds["model_2"]), \ 24 | "You must ensure you have an equal number of predictions per model" 25 | return model_preds 26 | 27 | 28 | def get_prompts_from_file(json_file: str): 29 | prompts = [] 30 | with open(json_file) as 
f: 31 | data = json.load(f) 32 | for d in data: 33 | if "prompt" in d: 34 | prompts.append(d["prompt"]) 35 | return prompts 36 | 37 | 38 | model_preds = get_model_predictions_from_file(args.model_predictions_json) 39 | images_1 = model_preds["model_1"] 40 | images_2 = model_preds["model_2"] 41 | 42 | prompts = get_prompts_from_file(args.model_predictions_json) 43 | 44 | # Question to evaluate 45 | col1, col2, col3 = st.columns([1, 3, 1]) 46 | with col1, col3: 47 | pass 48 | with col2: 49 | # Select choice for buttons 50 | selected_option = st.radio("Which image is more visually consistent with the prompt?", ["**A**", "**B**"], 51 | horizontal=True) 52 | 53 | 54 | def assign_images_and_prompt(): 55 | # Randomly assign one image to be A 56 | image_1 = images_1[st.session_state.curr_idx] 57 | image_2 = images_2[st.session_state.curr_idx] 58 | 59 | image_a = random.choice([image_1, image_2]) 60 | if image_a == image_1: 61 | st.session_state.model_a_assignments.append("model_1") 62 | image_b = image_2 63 | else: 64 | st.session_state.model_a_assignments.append("model_2") 65 | image_b = image_1 66 | 67 | image_a = Image.open(image_a) 68 | image_b = Image.open(image_b) 69 | st.session_state.image_a = image_a 70 | st.session_state.image_b = image_b 71 | 72 | if len(prompts) > 0: 73 | st.session_state.prompt = prompts[st.session_state.curr_idx] 74 | 75 | 76 | # Initialize session state 77 | if "curr_idx" not in st.session_state: 78 | st.session_state.curr_idx = 0 79 | st.session_state.click_disabled = False 80 | st.session_state.model_a_assignments = [] 81 | st.session_state.model_wins = Counter() 82 | st.session_state.wins_str = "" 83 | if "image_a" not in st.session_state: 84 | assign_images_and_prompt() 85 | 86 | 87 | def update_images_displayed(): 88 | if not st.session_state.click_disabled: 89 | # Increment model wins 90 | if st.session_state.curr_idx < len(images_1): 91 | if selected_option == "**A**": 92 | st.session_state.model_wins[st.session_state.model_a_assignments[st.session_state.curr_idx]] += 1 93 | else: 94 | if st.session_state.model_a_assignments[st.session_state.curr_idx] == "model_1": 95 | model_to_increment = "model_2" 96 | else: 97 | model_to_increment = "model_1" 98 | st.session_state.model_wins[model_to_increment] += 1 99 | st.session_state.curr_idx += 1 100 | if st.session_state.curr_idx >= len(images_1): 101 | st.session_state.click_disabled = True 102 | else: 103 | assign_images_and_prompt() 104 | 105 | 106 | def compute_scores_and_dump(): 107 | model_1_wins = st.session_state.model_wins["model_1"] / len(images_1) 108 | model_2_wins = st.session_state.model_wins["model_2"] / len(images_1) 109 | with open("scores.json", "w") as f: 110 | json.dump(st.session_state.model_wins, f) 111 | st.session_state.wins_str = f"Model 1 wins %: {model_1_wins * 100}, " \ 112 | f"Model 2 wins %: {model_2_wins * 100}" 113 | 114 | 115 | # Display images 116 | if not st.session_state.click_disabled: 117 | col1, col2 = st.columns(2) 118 | with col1: 119 | st.image(st.session_state.image_a, caption="Model A") 120 | with col2: 121 | st.image(st.session_state.image_b, caption="Model B") 122 | else: 123 | st.write("No more images left to process!") 124 | 125 | # Note about how many images have been processed 126 | st.markdown( 127 | """ 128 | 134 | """, unsafe_allow_html=True 135 | ) 136 | col1, col2, col3 = st.columns(3) 137 | with col1, col2: 138 | pass 139 | with col3: 140 | st.markdown( 141 | f"*Processing {st.session_state.curr_idx + 1 if not st.session_state.click_disabled else 
st.session_state.curr_idx}/{len(images_1)} images*") 142 | 143 | # Prompt provided to image gen systems 144 | col1, col2, col3 = st.columns(3) 145 | with col1, col3: 146 | pass 147 | with col2: 148 | st.write(f"Prompt: ***{prompts[st.session_state.curr_idx]}***" if ( 149 | len(prompts) > 0 and not st.session_state.click_disabled) else "") 150 | 151 | # Buttons to submit and compute win %s 152 | col1, col2, col3 = st.columns([3, 1, 1.5]) 153 | with col3: 154 | st.button("Submit", type="secondary", on_click=update_images_displayed, disabled=st.session_state.click_disabled) 155 | st.button("Compute Model Wins", type="secondary", on_click=compute_scores_and_dump) 156 | 157 | st.write(f"{st.session_state.wins_str}") 158 | -------------------------------------------------------------------------------- /image_eval/model_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities for manipulating models.""" 2 | import logging 3 | import os 4 | import requests 5 | 6 | from pathlib import Path 7 | 8 | 9 | def download_model(url: str, output_filename: str, output_dir: str = "~/.cache"): 10 | """Downloads a file from the given URL. Returns the output path.""" 11 | if output_dir == "~/.cache": 12 | home = str(Path.home()) 13 | output_dir = os.path.join(home, ".cache") 14 | if not os.path.exists(output_dir): 15 | os.makedirs(output_dir) 16 | 17 | output_path = os.path.join(output_dir, output_filename) 18 | if os.path.exists(output_path): 19 | logging.info(f"Reusing model cached at {output_path}.") 20 | return output_path 21 | 22 | logging.info(f"Downloading model from {url}...") 23 | output_dir = os.path.dirname(output_path) 24 | os.makedirs(output_dir, exist_ok=True) 25 | 26 | response = requests.get(url, stream=True) 27 | if not response.status_code == 200: 28 | raise RuntimeError(f"Unable to download model from {url}.") 29 | 30 | with open(output_path, "wb") as f: 31 | # Iterate over the response content in chunks 32 | for chunk in response.iter_content(chunk_size=1024): 33 | if chunk: 34 | f.write(chunk) 35 | 36 | return output_path 37 | -------------------------------------------------------------------------------- /image_eval/pairwise_evaluators.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from PIL import Image 4 | from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure 5 | from torchmetrics import PeakSignalNoiseRatio 6 | from torchmetrics import UniversalImageQualityIndex 7 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 8 | from torchvision import transforms 9 | from typing import Any, Dict, List 10 | 11 | from image_eval.evaluators import BaseEvaluator 12 | from image_eval.evaluators import EvaluatorType 13 | 14 | 15 | class PairwiseSimilarityEvaluator(BaseEvaluator): 16 | """Evaluates how similar pairs of images are. 17 | 18 | Helpful to compare against images generated from consecutive checkpoints during training. When 19 | these metrics start signaling high similarity between consecutive checkpoints, it's a sign we 20 | can stop training, since the model has converged. 21 | 22 | All the children of these methods are metrics found in HEIM: 23 | https://crfm.stanford.edu/helm/heim/latest/ 24 | """ 25 | TYPE = EvaluatorType.PAIRWISE_SIMILARITY 26 | 27 | def __init__(self, device: str, metric_name: str, metric: Any, normalize: bool = False): 28 | """Constructs a pairwise evaluator that uses a given metric from `torchmetrics`. 
29 | 30 | Args: 31 | device: "cpu" or "cuda" 32 | metric_name: A human-readable name for the metric 33 | metric: An object that has a __call__(images1, images2) method for pairwise comparison 34 | normalize: Whether images should be rescaled from [0, 1] to [-1, 1] before being passed to the metric 35 | """ 36 | super().__init__(device) 37 | self.metric_name = metric_name 38 | self.metric = metric.to(device) 39 | self.normalize = normalize 40 | 41 | def evaluate(self, 42 | generated_images: List[Image.Image], 43 | real_images: List[Image.Image], 44 | **unused_kwargs) -> Dict[str, float]: 45 | if len(generated_images) != len(real_images): 46 | raise ValueError("Pairwise evaluators expect 1:1 pairs of generated/real images.") 47 | 48 | img_transforms = [transforms.Resize((256, 256)), transforms.ToTensor()] 49 | if self.normalize: 50 | img_transforms.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))) 51 | 52 | preprocessing = transforms.Compose(img_transforms) 53 | generated_images = [preprocessing(img) for img in generated_images] 54 | real_images = [preprocessing(img) for img in real_images] 55 | 56 | generated_images = torch.stack(generated_images).to(self.device) 57 | real_images = torch.stack(real_images).to(self.device) 58 | 59 | score = self.metric(generated_images, real_images).detach().item() 60 | return {self.metric_name: score} 61 | 62 | 63 | class LPIPSEvaluator(PairwiseSimilarityEvaluator): 64 | """Calculates Learned Perceptual Image Patch Similarity (LPIPS) score. 65 | 66 | It computes the distance between the activations of two image patches for some deep network 67 | (by default, we use VGG). The score is between 0 and 1, where 0 means the images are identical. 68 | 69 | Original paper: https://arxiv.org/pdf/1801.03924 (published Apr 2018). 70 | """ 71 | HIGHER_IS_BETTER = False 72 | 73 | def __init__(self, device): 74 | super().__init__(device, "LPIPS", LearnedPerceptualImagePatchSimilarity(net_type="vgg"), 75 | normalize=True) 76 | 77 | 78 | class MultiSSIMEvaluator(PairwiseSimilarityEvaluator): 79 | """Calculates Multi-scale Structural Similarity Index Measure (MS-SSIM). 80 | 81 | This is an extension of SSIM, which assesses the similarity between two images based on three 82 | components: luminance, contrast, and structure. The score is between -1 and 1, where 1 = perfect 83 | similarity, 0 = no similarity, -1 = perfect anti-correlation. 84 | 85 | Original paper: https://ieeexplore.ieee.org/document/1292216 (published Apr 2004). 86 | """ 87 | HIGHER_IS_BETTER = True 88 | 89 | def __init__(self, device): 90 | super().__init__(device, "MultiSSIM", MultiScaleStructuralSimilarityIndexMeasure()) 91 | 92 | 93 | class PSNREvaluator(PairwiseSimilarityEvaluator): 94 | """Calculates Peak Signal-to-Noise Ratio (PSNR). 95 | 96 | It was originally designed to measure the quality of reconstructed or compressed images compared 97 | to their original versions. Its values are between -infinity and +infinity, where identical 98 | images score +infinity. 99 | 100 | Original paper: https://ieeexplore.ieee.org/document/1163711 (published Sep 2000). 101 | """ 102 | HIGHER_IS_BETTER = True 103 | 104 | def __init__(self, device): 105 | super().__init__(device, "PSNR", PeakSignalNoiseRatio()) 106 | 107 | 108 | class UIQIEvaluator(PairwiseSimilarityEvaluator): 109 | """Calculates Universal Image Quality Index (UIQI). 110 | 111 | Based on the idea of comparing statistical properties of an original and a distorted image in 112 | both the spatial and frequency domains.
The calculation involves several steps, including the 113 | computation of mean, variance, and covariance of the pixel values in local windows of the 114 | images. It also considers factors like luminance, contrast, and structure. 115 | 116 | Original paper: https://ieeexplore.ieee.org/document/1284395 (published Jul 2004). 117 | """ 118 | def __init__(self, device): 119 | super().__init__(device, "UIQUI", UniversalImageQualityIndex()) 120 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | image-reward==1.5 2 | insightface==0.7.3 3 | lpips==0.1.4 4 | networkx==3.1 5 | numpy 6 | onnxruntime==1.16.3 # required by insightface 7 | openai-clip==1.0.1 8 | piq==0.8.0 9 | pytorch-lightning==2.0.8 10 | requests==2.31.0 11 | scikit-learn==1.3.2 12 | streamlit==1.26.0 13 | sympy==1.12 14 | tabulate==0.9.0 15 | torch-fidelity==0.3.0 16 | transformers==4.40.1 17 | vendi-score==0.0.3 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages 2 | from setuptools import setup 3 | 4 | 5 | def readfile(filename): 6 | with open(filename, 'r') as f: 7 | return f.read() 8 | 9 | 10 | setup( 11 | name="image-eval", 12 | version="0.1.8", 13 | description="A library for evaluating image generation models", 14 | long_description=readfile("README.md"), 15 | long_description_content_type="text/markdown", 16 | author="Storia AI", 17 | author_email="founders@storia.ai", 18 | py_modules=["eval"], 19 | license=readfile("LICENSE"), 20 | packages=find_packages(include=["image_eval", "image_eval.*"]), 21 | entry_points={ 22 | "console_scripts": [ 23 | "image-eval = eval:main" 24 | ] 25 | }, 26 | install_requires=open("requirements.txt").readlines(), 27 | ) 28 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/tests/__init__.py -------------------------------------------------------------------------------- /tests/assets/fortune_teller.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/tests/assets/fortune_teller.png -------------------------------------------------------------------------------- /tests/assets/julie.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/tests/assets/julie.jpeg -------------------------------------------------------------------------------- /tests/assets/julie_gn_mean0_sigma100.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/tests/assets/julie_gn_mean0_sigma100.jpeg -------------------------------------------------------------------------------- /tests/assets/julie_gn_mean0_sigma50.jpeg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Storia-AI/image-eval/a6c5dd5cf525e840571c5d4eaf7d824c49ea3b25/tests/assets/julie_gn_mean0_sigma50.jpeg -------------------------------------------------------------------------------- /tests/evaluators_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | To run this test: 3 | 4 | pip install pytest 5 | export PYTHONPATH=$PYTHONPATH:/path/to/parent/of/image-eval 6 | pytest evaluators_test.py --log-cli-level=INFO 7 | """ 8 | import logging 9 | import torch 10 | import unittest 11 | 12 | from PIL import Image 13 | from image_eval.evaluators import ( 14 | AestheticPredictorEvaluator, 15 | EvaluatorType, 16 | get_evaluators_for_type, 17 | ) 18 | 19 | class TestEvaluators(unittest.TestCase): 20 | """Unit tests the evaluators by comparing scores of different image sets.""" 21 | 22 | @staticmethod 23 | def _single_eval(evaluator, image, other_image=None, prompt=None): 24 | """Extracts the score from the result of an evaluator, which is a {metric:score} dict.""" 25 | result = evaluator.evaluate([image], real_images=[other_image], prompts=[prompt]) 26 | assert len(result) == 1 27 | return list(result.values())[0] 28 | 29 | def test_image_quality_evaluators(self): 30 | """Tests that an original image scores higher than one with added Gaussian noise.""" 31 | img_orig = Image.open("assets/julie.jpeg") 32 | img_noise1 = Image.open("assets/julie_gn_mean0_sigma50.jpeg") 33 | img_noise2 = Image.open("assets/julie_gn_mean0_sigma100.jpeg") 34 | 35 | # TODO(julia): Add InceptionScoreEvaluator. 36 | for evaluator_cls in [AestheticPredictorEvaluator]: 37 | evaluator = evaluator_cls("cuda" if torch.cuda.is_available() else "cpu") 38 | score_orig = self._single_eval(evaluator, img_orig) 39 | score_noise1 = self._single_eval(evaluator, img_noise1) 40 | score_noise2 = self._single_eval(evaluator, img_noise2) 41 | logging.info(f"Scores from {evaluator}: {score_orig}, {score_noise1}, {score_noise2}") 42 | 43 | if evaluator_cls.HIGHER_IS_BETTER: 44 | self.assertGreater(score_orig, score_noise1) 45 | self.assertGreater(score_noise1, score_noise2) 46 | else: 47 | self.assertLess(score_orig, score_noise1) 48 | self.assertLess(score_noise1, score_noise2) 49 | 50 | def test_controllability_evaluators(self): 51 | """Tests that an image scores higher on a descriptive prompt than on a random prompt.""" 52 | img_dog = Image.open("assets/julie.jpeg") 53 | 54 | for evaluator_cls in get_evaluators_for_type(EvaluatorType.CONTROLLABILITY): 55 | evaluator = evaluator_cls("cuda" if torch.cuda.is_available() else "cpu") 56 | score_good_prompt = self._single_eval(evaluator, img_dog, prompt="a dog") 57 | score_bad_prompt = self._single_eval(evaluator, img_dog, prompt="a house") 58 | logging.info(f"Scores from {evaluator}: " 59 | f"good_prompt={score_good_prompt}, bad_prompt={score_bad_prompt}") 60 | 61 | if evaluator_cls.HIGHER_IS_BETTER: 62 | self.assertGreater(score_good_prompt, score_bad_prompt) 63 | else: 64 | self.assertLess(score_good_prompt, score_bad_prompt) 65 | 66 | def test_fidelity_evaluators(self): 67 | """Tests that datasets with similar images score higher than ones with different images.""" 68 | img_julie1 = Image.open("assets/julie.jpeg") 69 | img_julie2 = Image.open("assets/julie_gn_mean0_sigma50.jpeg") 70 | img_other = Image.open("assets/fortune_teller.png") 71 | 72 | for evaluator_cls in get_evaluators_for_type(EvaluatorType.FIDELITY): 73 | evaluator = evaluator_cls("cuda" if torch.cuda.is_available() else "cpu") 74 | 75 | # We're 
adding 2 images to each dataset, because some evaluators require at least 2. 76 | results_similar = evaluator.evaluate(generated_images=[img_julie2] * 2, 77 | real_images=[img_julie1] * 2) 78 | results_dissimilar = evaluator.evaluate(generated_images=[img_julie2] * 2, 79 | real_images=[img_other] * 2) 80 | 81 | # The same evaluator might return multiple metrics. 82 | for similar_key, similar_value in results_similar.items(): 83 | # TODO(julia & vlad): Figure out why insightface often fails or returns an empty 84 | # result even on pictures with human faces. 85 | if "insightface" in similar_key: 86 | continue 87 | 88 | dissimilar_value = results_dissimilar[similar_key] 89 | logging.info(f"Scores from {evaluator}/{similar_key}: " 90 | f"similar={similar_value}, dissimilar={dissimilar_value}") 91 | 92 | if evaluator_cls.HIGHER_IS_BETTER: 93 | self.assertGreater(similar_value, dissimilar_value) 94 | else: 95 | self.assertLess(similar_value, dissimilar_value) 96 | 97 | def test_diversity_evaluators(self): 98 | """Tests that a dataset of distinct images scores as more diverse than one of near-duplicates.""" 99 | img_julie1 = Image.open("assets/julie.jpeg") 100 | img_julie2 = Image.open("assets/julie_gn_mean0_sigma50.jpeg") 101 | img_other = Image.open("assets/fortune_teller.png") 102 | 103 | for evaluator_cls in get_evaluators_for_type(EvaluatorType.DIVERSITY): 104 | evaluator = evaluator_cls("cuda" if torch.cuda.is_available() else "cpu") 105 | 106 | results_not_diverse = evaluator.evaluate([img_julie1, img_julie2]) 107 | results_diverse = evaluator.evaluate([img_julie1, img_other]) 108 | 109 | for key_not_diverse, value_not_diverse in results_not_diverse.items(): 110 | # TODO(julia & vlad): Figure out why insightface often fails or returns an empty 111 | # result even on pictures with human faces. 112 | if "insightface" in key_not_diverse: 113 | continue 114 | 115 | value_diverse = results_diverse[key_not_diverse] 116 | logging.info(f"Scores from {evaluator}/{key_not_diverse}: " 117 | f"not_diverse={value_not_diverse}, diverse={value_diverse}") 118 | 119 | if evaluator_cls.HIGHER_IS_BETTER: 120 | self.assertGreater(value_diverse, value_not_diverse) 121 | else: 122 | self.assertLess(value_diverse, value_not_diverse) 123 | 124 | 125 | if __name__ == "__main__": 126 | unittest.main() 127 | --------------------------------------------------------------------------------
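Usage sketch (editorial addition, not a file in the repository): a minimal example of calling the pairwise evaluators from `image_eval/pairwise_evaluators.py` directly, assuming the package is installed per the README. The `ckpt_a/`/`ckpt_b/` paths are hypothetical placeholders for 1:1 pairs of images generated from the same prompts by two different checkpoints.

```python
import torch
from PIL import Image

from image_eval.pairwise_evaluators import LPIPSEvaluator, PSNREvaluator

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1:1 pairs: the i-th image in each list was generated from the same prompt.
# Directory names are hypothetical; point them at your own checkpoint outputs.
images_ckpt_a = [Image.open(p).convert("RGB") for p in ("ckpt_a/0.png", "ckpt_a/1.png")]
images_ckpt_b = [Image.open(p).convert("RGB") for p in ("ckpt_b/0.png", "ckpt_b/1.png")]

for evaluator_cls in (LPIPSEvaluator, PSNREvaluator):
    evaluator = evaluator_cls(device)
    # Each evaluator returns a single-entry {metric_name: score} dict.
    print(evaluator.evaluate(images_ckpt_a, real_images=images_ckpt_b))
```

Since `HIGHER_IS_BETTER` is `False` for LPIPS and `True` for PSNR, converging checkpoints should push LPIPS toward 0 and PSNR upward.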