├── config
│   └── config_vis
│       ├── config.yaml
│       └── vis
│           ├── default.yaml
│           └── resnet18.yaml
├── .gitignore
├── imgs
│   ├── tsne_resnet18_race.png
│   ├── tsne_resnet18_gender.png
│   └── latentplay_teaser.drawio.png
├── scripts
│   ├── paths.sh
│   ├── download_dlib68.sh
│   └── download_fairface.sh
├── env.yaml
├── gen_dataset.py
├── LICENSE
├── add_embedding.py
├── utils.py
├── README.md
├── loader_configs.py
└── visualization.py

/config/config_vis/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - vis: default
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | **/multirun
3 | **/models
4 | **/Experiments
5 | **/datasets
--------------------------------------------------------------------------------
/imgs/tsne_resnet18_race.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parsa-ra/LatentPlayInterface/HEAD/imgs/tsne_resnet18_race.png
--------------------------------------------------------------------------------
/imgs/tsne_resnet18_gender.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parsa-ra/LatentPlayInterface/HEAD/imgs/tsne_resnet18_gender.png
--------------------------------------------------------------------------------
/imgs/latentplay_teaser.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parsa-ra/LatentPlayInterface/HEAD/imgs/latentplay_teaser.drawio.png
--------------------------------------------------------------------------------
/config/config_vis/vis/default.yaml:
--------------------------------------------------------------------------------
1 | name: 'test'
2 | dataset_full_path: "./datasets/"
3 | embedding_key: resnet18
4 | proj_algorithm: tsne
5 | discard: False
6 | max_unique_values: 10
7 | label_keys:
8 |   - gender
--------------------------------------------------------------------------------
/scripts/paths.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | LP_FAIRFACE_PATH=/path/to/dataset
4 | LP_MORPH_PATH=/path/to/dataset
5 | LP_RFW_PATH=/path/to/dataset
6 | LP_UTKFACE_PATH=/path/to/dataset
7 | 
8 | LP_DLIB_PREDICTOR=/path/to/dlib_predictor
--------------------------------------------------------------------------------
/config/config_vis/vis/resnet18.yaml:
--------------------------------------------------------------------------------
1 | name: 'test'
2 | dataset_full_path: "datasets/fairface2latentplay_resnet18.hf"
3 | embedding_key: resnet18
4 | proj_algorithm: tsne
5 | discard: False
6 | max_unique_values: 10
7 | label_keys:
8 |   - gender
--------------------------------------------------------------------------------
/scripts/download_dlib68.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | mkdir -p models && cd models
4 | wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
5 | bzip2 -d shape_predictor_68_face_landmarks.dat.bz2
6 | cd ..
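# Note (added comment): the export below only takes effect in the current shell if this
# script is sourced (e.g. `source ./scripts/download_dlib68.sh`); when it is executed as
# a regular script, set LP_DLIB_PREDICTOR yourself afterwards (see scripts/paths.sh).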
7 | export LP_DLIB_PREDICTOR="$(pwd)/models/shape_predictor_68_face_landmarks.dat"
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: latentplay
2 | channels:
3 |   - conda-forge
4 |   - pytorch
5 |   - nvidia
6 |   - pytorch3d
7 |   - huggingface
8 | dependencies:
9 |   - python=3.10
10 |   - dlib
11 |   - pytorch
12 |   - torchvision
13 |   - datasets
14 |   - scipy
15 |   - hydra-core
16 |   - omegaconf
17 |   - gdown
18 |   - bokeh
19 |   - Pillow
20 |   - tqdm
21 |   - pip:
22 |     - python-json-logger
23 |     - timm
24 |     - umap-learn
--------------------------------------------------------------------------------
/scripts/download_fairface.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | echo "Creating the datasets directory"
4 | mkdir -p datasets/fairface && cd datasets/fairface
5 | echo "Downloading the fairface dataset ... "
6 | 
7 | gdown 1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86 # From the fairface repository
8 | gdown 1i1L3Yqwaio7YSOCj7ftgk8ZZchPG7dmH
9 | gdown 1wOdja-ezstMEp81tX1a-EYkFebev4h7D
10 | 
11 | unzip fairface-img-margin025-trainval.zip
12 | 
13 | # We are already inside datasets/fairface at this point, so $(pwd) is the dataset root.
14 | export LP_FAIRFACE_PATH="$(pwd)"
--------------------------------------------------------------------------------
/gen_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, logging, disable_caching
2 | from argparse import ArgumentParser
3 | from multiprocessing import cpu_count
4 | import os
5 | 
6 | parser = ArgumentParser()
7 | parser.add_argument("dataset_name", choices=[
8 |     "fairface",
9 |     "morph",
10 |     "rfw",
11 |     "utkface",
12 | ])
13 | parser.add_argument("--output_path", default="datasets")
14 | 
15 | args = parser.parse_args()
16 | os.makedirs(args.output_path, exist_ok=True)
17 | output_path = args.output_path
18 | 
19 | logging.set_verbosity_info()
20 | 
21 | dataset_builder_name = args.dataset_name + "2latentplay"
22 | data = load_dataset(f"./builders/{dataset_builder_name}", num_proc=cpu_count()-1)
23 | 
24 | output_path = os.path.join(args.output_path, dataset_builder_name + ".hf")
25 | print(f"Dataset will be saved at: {output_path}")
26 | data.save_to_disk(output_path)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Parsa Rahimi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/add_embedding.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from datasets import load_from_disk
4 | from io import BytesIO
5 | from PIL import Image
6 | from traceback import print_exc
7 | import torch
8 | from torchvision.models import resnet18, ResNet18_Weights
9 | 
10 | from argparse import ArgumentParser
11 | 
12 | parser = ArgumentParser()
13 | parser.add_argument('dataset_path')
14 | 
15 | args = parser.parse_args()
16 | 
17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
18 | 
19 | dataset = load_from_disk(args.dataset_path)
20 | 
21 | aligned_key = "image"
22 | 
23 | arch = "resnet18"
24 | weights = ResNet18_Weights.DEFAULT
25 | preprocess = weights.transforms()
26 | model = resnet18(weights=weights)  # load the pretrained weights instead of a randomly initialised network
27 | model.fc = torch.nn.Identity()     # strip the classification head so the 512-d penultimate features are stored (matches the 512-d zero vector used on failure below)
28 | model.eval()
29 | 
30 | model.to(device)
31 | 
32 | output_key = arch
33 | output_status_key = "status_" + arch
34 | 
35 | 
36 | def embedder(sample):
37 |     try:
38 |         # Loading the aligned image
39 |         img = preprocess(Image.open(BytesIO(sample[aligned_key])).convert('RGB'))
40 | 
41 |         #print(img.shape)
42 |         img = img.unsqueeze(0)
43 | 
44 |         img = img.to(device)
45 |         with torch.no_grad():
46 |             feats = model(img)
47 | 
48 |         feats = feats.detach().to('cpu')
49 | 
50 |         #print(feats.shape)
51 |         sample[output_status_key] = True
52 |         sample[output_key] = feats[0, :]
53 | 
54 |         del img, feats
55 | 
56 |     except Exception:
57 |         print("Failed to compute the embedding for a sample")
58 |         print_exc()
59 | 
60 |         sample[output_status_key] = False
61 |         sample[output_key] = torch.zeros(512, dtype=torch.float32)
62 | 
63 |     return sample
64 | 
65 | dataset = dataset.map(embedder, writer_batch_size=1000)
66 | 
67 | dataset.save_to_disk(args.dataset_path[:-3] + f"_{arch}.hf")
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config as config
3 | from logging.handlers import TimedRotatingFileHandler
4 | from pythonjsonlogger.jsonlogger import JsonFormatter
5 | import sys
6 | import os
7 | 
8 | def get_env_key_val(name: str, default: str):
9 |     return os.getenv(name) if os.getenv(name) else default
10 | 
11 | # Log levels are given as level names (e.g. "WARNING"), matching the keys of logging._nameToLevel.
12 | stdout_loglvl = get_env_key_val("STDOUT_LOG_LEVEL", "WARNING")
13 | if stdout_loglvl not in logging._nameToLevel:
14 |     stdout_loglvl = "DEBUG"
15 | print("Setting stdout loglvl to {}".format(stdout_loglvl))
16 | 
17 | file_loglvl = get_env_key_val("FILE_LOG_LEVEL", "DEBUG")
18 | if file_loglvl not in logging._nameToLevel:
19 |     file_loglvl = "DEBUG"
20 | print("Setting file loglvl to {}".format(file_loglvl))
21 | 
22 | base_directory = "logs"
23 | os.makedirs(base_directory, exist_ok=True)  # the rotating file handler below needs this directory to exist
24 | dirname = os.path.dirname(os.path.abspath(__file__))
25 | parent_folder_name = dirname.split('/')[-1]
26 | APP_NAME = parent_folder_name
27 | APP_VERSION = "0.1"
28 | 
29 | def get_logger(name: str, add_stdout=True, add_file=True, default_log_level=logging.DEBUG):
30 |     logger = logging.getLogger(name)
31 |     logger.setLevel(stdout_loglvl)
32 | 
33 |     fmt = logging.Formatter('%(name)s | %(levelname)s | %(asctime)s | %(filename)s | %(funcName)s | %(lineno)s | %(message)s')
34 |     json_fmt = JsonFormatter('%(levelname)s %(message)s %(asctime)s %(name)s %(funcName)s %(lineno)d %(thread)d %(pathname)s', json_ensure_ascii=False)
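    # Added comment: two formatters are used on purpose, the stdout handler gets the
    # human-readable pipe-separated format above, while the file handler below writes
    # one JSON object per record (rotated daily) so the logs stay machine-parseable.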
35 | 
36 |     if add_stdout:
37 |         stream_handler = logging.StreamHandler(sys.stdout)
38 |         stream_handler.setLevel(stdout_loglvl)
39 |         stream_handler.setFormatter(fmt)
40 |         logger.addHandler(stream_handler)
41 | 
42 |     if add_file:
43 |         # file_handler = logging.FileHandler(f"./logs_{name}.txt", encoding='utf-8')
44 |         file_handler = TimedRotatingFileHandler(os.path.join(base_directory, f"./logs_{name}.json"), when='d', interval=1, backupCount=60)
45 |         file_handler.setLevel(file_loglvl)
46 |         file_handler.setFormatter(json_fmt)
47 |         logger.addHandler(file_handler)
48 | 
49 |     if not add_stdout and not add_file:
50 |         print("You should at least set one of the handlers to True")
51 |         sys.exit(0)
52 | 
53 |     return logger
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # LatentPlay Interface
3 | 
4 | ![teaser](./imgs/latentplay_teaser.drawio.png)
5 | 
6 | Assume that you have a dataset of images and multiple labels for each image, for example images of human faces labelled with the gender, race and age of the person, or images of objects labelled with a name drawn from a discrete set of possible values. Now consider the case where we also want to compute an embedding of each image with different algorithms, or store post-processed versions of the images. The common practice is to keep each of these data types in a different file or folder (for example an `.npy` file for the embeddings, a separate folder for the post-processed images, and a CSV file that records the paths). Here we introduce our `LatentPlay` interface, which unifies this storage and simplifies reuse by building on the 🤗 [Datasets](https://huggingface.co/docs/datasets/en/index) library.
7 | 
8 | 
9 | ## TLDR:
10 | A **practice** to handle multi-modal datasets in a unified way.
11 | 
12 | This folder contains scripts that convert various datasets to my unified LatentPlay interface, built on the Hugging Face `datasets` library, so that further experiments require minimal code changes. You can find some sample usages in the next sections.
13 | 
14 | Currently it supports the following datasets:
15 | - [FairFace](https://github.com/joojs/fairface)
16 | - [MORPH](https://ieeexplore.ieee.org/document/1613043)
17 | - [RFW](http://www.whdeng.cn/RFW/index.html)
18 | - [UTKFace](https://susanqq.github.io/UTKFace/)
19 | 
20 | Pull requests are welcome 🤝🏻.
21 | 
22 | Given the path to each original dataset as published by its authors, the scripts generate an `.hf` folder with the dataset packed into binary files.
23 | 
24 | **Note**: The scripts are designed around the original distribution of the datasets (in terms of folder hierarchy) as released by the corresponding authors.
25 | 
26 | Later you can simply load the dataset in your Python script:
27 | ```python
28 | from datasets import load_from_disk
29 | dataset = load_from_disk('path/to/your/dataset.hf')
30 | ```
31 | 
32 | 
33 | ## Getting Started
34 | Create the conda environment:
35 | ```bash
36 | mamba/conda/micromamba create -f ./env.yaml
37 | mamba activate latentplay
38 | ```
39 | 
40 | 
41 | ## Making LatentPlay version of Fairface [OPTIONAL]
42 | 
43 | Download the [FairFace](https://drive.google.com/file/d/1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86/view) dataset.
44 | 
45 | We provide a helper script for that, just run:
46 | ```bash
47 | chmod +x ./scripts/download_fairface.sh
48 | ./scripts/download_fairface.sh
49 | ```
50 | 
51 | I also provide a preprocessed version of this dataset; you can download it from the releases page.
52 | 
53 | 
54 | # Playing around with the dataset
55 | Now that you have a dataset and all of its metadata, you can play around with it, for example by adding a new column that holds an embedding of each image (yes, you can add vectors to a dataset that already contains images and metadata in different formats!), which makes it easy to manage our experiments.
56 | 
57 | ## Sample#1: Adding an embedding as a new column
58 | The following script extracts an embedding for the `image` column of our dataset and saves it in a new `resnet18` column.
59 | Optionally, you can download the dataset with pre-calculated embeddings from the releases page.
60 | 
61 | ```bash
62 | python3 add_embedding.py datasets/fairface2latentplay.hf
63 | ```
64 | 
65 | The new dataset with the additional embedding column will be saved next to the original one.
66 | 
67 | 
68 | ## Sample#2: Visualization of the dataset using t-SNE
69 | Now that we have both the embeddings and their corresponding labels in a **unified dataset**, let's see how to draw a `t-SNE` plot of the embeddings grouped by different attributes (e.g. gender, race, ...).
70 | 
71 | ```bash
72 | python3 visualization.py --multirun vis=resnet18
73 | ```
74 | The plots will be saved next to the dataset path.
75 | 
76 | Or you can do the same for the `race` column of the dataset:
77 | ```bash
78 | python3 visualization.py --multirun vis=resnet18 vis.label_keys=['race']
79 | ```
80 | 
81 | | Gender | Race |
82 | |--------------------------------------|--------------------------------------|
83 | | ![t-SNE coloured by gender](./imgs/tsne_resnet18_gender.png) | ![t-SNE coloured by race](./imgs/tsne_resnet18_race.png) |
84 | 
85 | 
86 | We can also draw the plots using `umap`; in this case you can run:
87 | ```bash
88 | python3 visualization.py --multirun vis=resnet18 vis.proj_algorithm=umap
89 | ```
90 | 
91 | # TODO:
92 | Add more sample use cases ...
93 | 
94 | # Disclaimer
95 | This repository is part of our full paper code release originally published [here](https://gitlab.idiap.ch/biometric/sg_latent_modeling).
96 | 
97 | # Citation
98 | If you find this framework useful, please consider citing our paper.
99 | 
100 | ```bibtex
101 | @inproceedings{rahimi2023toward,
102 |   title={Toward responsible face datasets: modeling the distribution of a disentangled latent space for sampling face images from demographic groups},
103 |   author={Rahimi, Parsa and Ecabert, Christophe and Marcel, S{\'e}bastien},
104 |   booktitle={2023 IEEE International Joint Conference on Biometrics (IJCB)},
105 |   pages={1--11},
106 |   year={2023},
107 |   organization={IEEE}
108 | }
109 | ```
--------------------------------------------------------------------------------
/loader_configs.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import io
3 | import os
4 | from glob import glob
5 | from tqdm import tqdm
6 | import numpy as np
7 | import PIL
8 | from PIL import Image
9 | import scipy
10 | import dlib
11 | 
12 | from datasets import Sequence
13 | 
14 | preprocessing_target_size = 256
15 | 
16 | # landmark detector
17 | dlib_landmark_detector_path = os.getenv("LP_DLIB_PREDICTOR")
18 | 
19 | naming_splitter = "~"
20 | 
21 | ######################## Specs ###########################
22 | 
23 | male = "male"
24 | female = "female"
25 | 
26 | race_map = {
27 |     "asian": "asian",
28 |     "black": "black",
29 |     "african": "black",
30 |     "caucasian": "white",
31 |     "white": "white",
32 |     "middleeastern": "middleeastern",
33 |     "indian": "indian",
34 |     "hispanic": "hispanic",
35 |     "unknown": "unknown",
36 | 
37 |     "O": "unknown",
38 |     "A": "asian",
39 |     "B": "black",
40 |     "H": "hispanic",
41 |     "W": "white",
42 | 
43 |     "East Asian": "asian", # TODO
44 |     "Indian": "indian",
45 |     "Black": "black",
46 |     "Middle Eastern": "middleeastern",
47 |     "White": "white",
48 |     "Latino_Hispanic": "hispanic",
49 |     "Southeast Asian": "asian", #TODO
50 | 
51 |     # From https://susanqq.github.io/UTKFace/
52 |     # utkface
53 |     "0": "white",
54 |     "1": "black",
55 |     "2": "asian",
56 |     "3": "indian",
57 |     "4": "unknown"
58 | }
59 | 
60 | gender_map = {
61 |     "m": "male",
62 |     "f": "female",
63 |     "M": "male",
64 |     "male": "male",
65 |     "man": "male",
66 |     "F": "female",
67 |     "women": "female",
68 |     "female": "female",
69 |     "unknown": "unknown",
70 | 
71 |     "Male": "male",
72 |     "Female": "female",
73 | 
74 |     # From https://susanqq.github.io/UTKFace/
75 |     # utkface
76 |     "0": "male",
77 |     "1": "female",
78 | }
79 | 
80 | 
81 | # TODO: add some kind of versioning of sorts
82 | 
83 | dataset_features = {
84 |     "person_id": datasets.Value("string"), # Unique id used to tell whether two samples belong to the same identity
85 |     "image": datasets.Value("binary"),
86 | 
87 |     "dlib_align_status": datasets.Value("bool"),
88 |     "image_dlib_aligned": datasets.Value("binary"),
89 | 
90 | 
91 |     "gender": datasets.Value("string"),
92 |     "race": datasets.Value("string"),
93 |     "age": datasets.Value("string"),
94 |     "human": datasets.Value("bool"), # Used for human non-human stuff ...
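    # Added comment: embedding columns such as "resnet18" and its "status_resnet18" flag
    # are not declared here; add_embedding.py appends them later via dataset.map(), and
    # the datasets library extends the schema at that point.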
95 | } 96 | 97 | 98 | def pre_process_images(raw_image_path, output_path, predictor): 99 | current_directory = os.getcwd() 100 | print(current_directory) 101 | 102 | aligned_images = [] 103 | try: 104 | aligned_image = align_face(filepath=raw_image_path, 105 | predictor=predictor, output_size=preprocessing_target_size) 106 | aligned_images.append(aligned_image) 107 | except Exception as e: 108 | print(e) 109 | 110 | os.makedirs(output_path, exist_ok=True) 111 | images_names = [raw_image_path.split('/')[-1]] 112 | for image, name in zip(aligned_images, images_names): 113 | # Name without extensions 114 | real_name = name.split('.')[0] 115 | image.save(f'{output_path}/{real_name}.jpeg') 116 | 117 | os.chdir(current_directory) 118 | 119 | ## Borrowed from Insightfaces repository 120 | def get_landmark(filepath, predictor): 121 | """get landmark with dlib 122 | :return: np.array shape=(68, 2) 123 | """ 124 | detector = dlib.get_frontal_face_detector() 125 | 126 | img = dlib.load_rgb_image(filepath) 127 | dets = detector(img, 1) 128 | 129 | for k, d in enumerate(dets): 130 | shape = predictor(img, d) 131 | 132 | t = list(shape.parts()) 133 | a = [] 134 | for tt in t: 135 | a.append([tt.x, tt.y]) 136 | lm = np.array(a) 137 | return lm 138 | 139 | 140 | def align_face(filepath, predictor, output_size): 141 | """ 142 | :param filepath: str 143 | :return: PIL Image 144 | """ 145 | 146 | lm = get_landmark(filepath, predictor) 147 | 148 | lm_chin = lm[0: 17] # left-right 149 | lm_eyebrow_left = lm[17: 22] # left-right 150 | lm_eyebrow_right = lm[22: 27] # left-right 151 | lm_nose = lm[27: 31] # top-down 152 | lm_nostrils = lm[31: 36] # top-down 153 | lm_eye_left = lm[36: 42] # left-clockwise 154 | lm_eye_right = lm[42: 48] # left-clockwise 155 | lm_mouth_outer = lm[48: 60] # left-clockwise 156 | lm_mouth_inner = lm[60: 68] # left-clockwise 157 | 158 | # Calculate auxiliary vectors. 159 | eye_left = np.mean(lm_eye_left, axis=0) 160 | eye_right = np.mean(lm_eye_right, axis=0) 161 | eye_avg = (eye_left + eye_right) * 0.5 162 | eye_to_eye = eye_right - eye_left 163 | mouth_left = lm_mouth_outer[0] 164 | mouth_right = lm_mouth_outer[6] 165 | mouth_avg = (mouth_left + mouth_right) * 0.5 166 | eye_to_mouth = mouth_avg - eye_avg 167 | 168 | # Choose oriented crop rectangle. 169 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 170 | x /= np.hypot(*x) 171 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 172 | y = np.flipud(x) * [-1, 1] 173 | c = eye_avg + eye_to_mouth * 0.1 174 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 175 | qsize = np.hypot(*x) * 2 176 | 177 | # read image 178 | img = PIL.Image.open(filepath) 179 | 180 | transform_size = output_size 181 | enable_padding = True 182 | 183 | # Shrink. 184 | shrink = int(np.floor(qsize / output_size * 0.5)) 185 | if shrink > 1: 186 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 187 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 188 | quad /= shrink 189 | qsize /= shrink 190 | 191 | # Crop. 192 | border = max(int(np.rint(qsize * 0.1)), 3) 193 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 194 | int(np.ceil(max(quad[:, 1])))) 195 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), 196 | min(crop[3] + border, img.size[1])) 197 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 198 | img = img.crop(crop) 199 | quad -= crop[0:2] 200 | 201 | # Pad. 
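    # Added comment: the block below reflection-pads the image wherever the crop
    # rectangle extends past its borders, then blurs the padded region and fades it
    # towards the median colour so the aligned crop shows no hard seams.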
202 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), 203 | int(np.ceil(max(quad[:, 1])))) 204 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), 205 | max(pad[3] - img.size[1] + border, 0)) 206 | if enable_padding and max(pad) > border - 4: 207 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 208 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 209 | h, w, _ = img.shape 210 | y, x, _ = np.ogrid[:h, :w, :1] 211 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]), 212 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3])) 213 | blur = qsize * 0.02 214 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 215 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 216 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 217 | quad += pad[:2] 218 | 219 | # Transform. 220 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BICUBIC) 221 | if output_size < transform_size: 222 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 223 | 224 | # Return aligned image. 225 | return img -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | from enum import unique 2 | from utils import get_logger 3 | from sklearn.manifold import TSNE 4 | from bokeh.plotting import figure, show 5 | from bokeh.transform import factor_cmap, factor_mark 6 | from bokeh.layouts import gridplot 7 | from bokeh.models import ColumnDataSource 8 | from typing import List 9 | from copy import copy, deepcopy 10 | from joblib import dump, load 11 | from glob import glob 12 | import numpy as np 13 | import pandas as pd 14 | from datasets import load_from_disk, Dataset 15 | from copy import copy 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import matplotlib 19 | import hydra 20 | from omegaconf import DictConfig, OmegaConf 21 | from multiprocessing import cpu_count 22 | 23 | import os 24 | 25 | logger = get_logger("Manifold", True, False) 26 | 27 | experiments_base_path = "./Experiments" 28 | projection_base_path = os.path.join(experiments_base_path, 'projection') 29 | 30 | fw, fh = 1200, 1200 31 | 32 | np.random.seed(13) 33 | 34 | def generate_distinct_colors(n): 35 | colors = plt.cm.get_cmap("hsv", n) 36 | return [matplotlib.colors.rgb2hex(colors(i)[:3]) for i in range(n)] 37 | 38 | 39 | MARKERS = ['hex', 'circle_x', 'triangle', 'square', 'circle', 'dot', 'asterisk', 'diamond'] 40 | COLORS = generate_distinct_colors(20) 41 | np.random.shuffle(COLORS) 42 | print(COLORS) 43 | 44 | change_label_name = False 45 | embeddings = None 46 | metas = None 47 | FORCE_RERUN = False 48 | 49 | print("Creating projection base path ... 
") 50 | os.makedirs(projection_base_path, exist_ok=True) 51 | 52 | colormap = dict() 53 | markermap = dict() 54 | 55 | @hydra.main(version_base=None, config_path="config/config_vis", config_name="config.yaml") 56 | def tsne(cfg: DictConfig): 57 | cfg = cfg.vis 58 | 59 | name = cfg.name 60 | 61 | dataset_full_path = cfg.dataset_full_path 62 | print(dataset_full_path) 63 | 64 | embedding_key = cfg.embedding_key 65 | label_keys = cfg.label_keys 66 | 67 | assert len(label_keys) <= 2 , "Right know for tnsne we can only show two label at a time" 68 | label_map = dict() 69 | 70 | label_precedence = ["colormap", "markermap"] 71 | value_precedence = [ "COLORS", "MARKERS" ] 72 | values = ["color", "marker"] 73 | 74 | dataset: Dataset = load_from_disk(dataset_full_path) 75 | dataset_keys = list() 76 | dataset_keys.append(embedding_key) 77 | dataset_keys.extend(label_keys) 78 | 79 | #prefix_name= dataset['train']._fingerprint # Inferfrom name or fingerprint 80 | prefix_name = f"{cfg.proj_algorithm.upper()}s" 81 | exp_name = "-".join([prefix_name, embedding_key, *label_keys, dataset['train']._fingerprint]) 82 | 83 | status_prefix = "status" 84 | for key in dataset: 85 | status_key_name = status_prefix + "_"+embedding_key 86 | dataset_check = dataset[key].select_columns(column_names=[embedding_key, status_key_name]) 87 | dataset_check = dataset_check.to_pandas() 88 | print(dataset_check[status_key_name].value_counts().to_dict()) 89 | 90 | 91 | if cfg.discard: 92 | print(f"Filtering enteries with their {cfg.discard + embedding_key} set to true ...") 93 | dataset = dataset.filter(lambda entry: entry[ cfg.discard + embedding_key], 94 | num_proc=min(1,cpu_count()//2)) # There is some problem with the datasets' filtering process with multiple processes 95 | print("Filtering finished.") 96 | 97 | dataset.set_format("numpy", columns=dataset_keys) 98 | 99 | if 'test' not in dataset and 'train' in dataset: 100 | logger.warn('Test shard is not in the dataset, will split the dataset automatically') 101 | dataset = dataset['train'].train_test_split(test_size=0.1) 102 | else: 103 | logger.error('Unknown split names, the dataset should either contain an `train` or `test`split') 104 | 105 | train_embeddings = dataset['train'][embedding_key] 106 | test_embeddings = dataset['test'][embedding_key] 107 | 108 | test_labels = list() 109 | 110 | for label_key in label_keys: 111 | test_labels.append(dataset['test'][label_key]) 112 | 113 | unique_keys: List[List] = list() 114 | 115 | for label_idx, label_key in enumerate(label_keys): 116 | unique_keys.append(list()) 117 | 118 | cur_labels = test_labels[label_idx] 119 | for cur_label in cur_labels: 120 | if cur_label not in unique_keys[-1]: 121 | print(f"Adding {cur_label} for {label_key}") 122 | unique_keys[-1].append(cur_label) 123 | 124 | 125 | # import sys 126 | print("\n\n\n", unique_keys) 127 | max_unique_key_len = 0 128 | for unique_key in unique_keys: 129 | max_unique_key_len = max(len(unique_key), max_unique_key_len) 130 | print(max_unique_key_len, "\n\n\n") 131 | # sys.exit(0) 132 | 133 | plots = list() 134 | 135 | title="-".join([prefix_name, embedding_key, *label_keys]) 136 | segmented_title = "" 137 | cur_pos = 0 138 | pos_step = 25 139 | while cur_pos < len(title): 140 | segmented_title += title[cur_pos: cur_pos+pos_step] + "\n" 141 | cur_pos += pos_step 142 | 143 | 144 | print(f"Original Title was: {title}\nSegmented title is: {segmented_title}\n") 145 | p = figure(title = segmented_title, frame_width=fw, frame_height=fh, background_fill_color="#fafafa") 146 | 
147 | cur_embeddings = train_embeddings 148 | projector_name = exp_name 149 | projector_path = os.path.join(projection_base_path ,f"{projector_name}_manifoldlearning.pickle") 150 | projector_sklearn_api = None 151 | if os.path.exists(projector_path) and not FORCE_RERUN: 152 | logger.info(f"It seems that the projector file for current config already exists in {projector_path}, ... Loading it from disk instead of re-running the tsne") 153 | with open(projector_path, 'rb') as file: 154 | projector_sklearn_api = load(file) 155 | else: 156 | logger.info(f"The tsne file on disk cannot be found at {projector_path}, running the tsne training ... ") 157 | 158 | if cfg.proj_algorithm == "tsne": 159 | logger.info(f"Setting projection algorithm to TSNE") 160 | projector_sklearn_api = TSNE( 161 | n_components=2, 162 | learning_rate="auto", 163 | perplexity=50, 164 | n_iter=1000, 165 | n_jobs=-1, 166 | init="pca" 167 | ) 168 | elif cfg.proj_algorithm == "umap": 169 | logger.info(f"Setting projection algorithm to UMAP") 170 | import umap 171 | projector_sklearn_api = umap.UMAP() 172 | else: 173 | logger.error(f"Unsupported projection algorithm {cfg.proj_algorithm}") 174 | raise ValueError() 175 | 176 | projector_sklearn_api.fit_transform(train_embeddings) 177 | 178 | with open(projector_path, 'wb') as file: 179 | dump(projector_sklearn_api, file) 180 | 181 | colormap = dict() 182 | markermap = dict() 183 | 184 | for label_idx, label_key in enumerate(label_keys): 185 | print(label_key, label_idx) 186 | for idx, entry in enumerate(unique_keys[label_idx]): 187 | print(entry, idx) 188 | print(label_precedence[label_idx], value_precedence[label_idx]) 189 | 190 | #print(globals()[label_precedence[label_idx]]) 191 | #print(globals()[value_precedence[label_idx]]) 192 | 193 | globals()[label_precedence[label_idx]][entry] = globals()[value_precedence[label_idx]][idx] 194 | 195 | colormap[entry] = COLORS[idx] 196 | 197 | # colormap[entry] = COLORS[idx] 198 | # markermap[entry] = MARKERS[idx] 199 | 200 | #print("\n\n\nColorMap", colormap) 201 | 202 | # if change_label_name: 203 | # mapping_label_name = dict() 204 | # for key in unique_keys: 205 | # alternative_name = str(input(f"Enter name alternative name for the {key}: ")) 206 | # mapping_label_name[key] = alternative_name 207 | # for idx, meta in enumerate(metas): 208 | # metas[idx] = mapping_label_name[meta] 209 | 210 | projected = projector_sklearn_api.fit_transform(test_embeddings) 211 | 212 | #print(list(colormap.keys())) 213 | source = ColumnDataSource(dict( 214 | x=projected[:,0], 215 | y=projected[:,1], 216 | # marker = [ markermap[meta] for meta in test_labels[1] ], 217 | color = [ colormap[meta] for meta in test_labels[0] ], 218 | label = test_labels[0] 219 | ) 220 | ) 221 | 222 | print("Plotting scatter plot ... 
") 223 | p.scatter( 224 | x = 'x', 225 | y = 'y', 226 | source = source, 227 | size = 14, 228 | color ='color', 229 | #marker ='marker', 230 | legend_group ='label', 231 | fill_alpha =0.5, 232 | ) 233 | p.legend.location = "top_left" 234 | p.legend.title = "-".join([prefix_name, embedding_key, *label_keys]) 235 | print("plotting scatter done") 236 | 237 | p.legend.background_fill_alpha = 0.3 238 | p.legend.title_text_font_size = '35pt' 239 | p.legend.label_text_font_size = '38pt' 240 | p.title.text_font_size = "40pt" 241 | 242 | plots.append(copy(p)) 243 | 244 | # from joblib import dump 245 | # with open('./plots_tsne_vaule_wp', 'wb') as file: 246 | # dump(plots, file) 247 | # num_columns = max(0, int(embeddings.shape[1] ** (1/2))) 248 | # num_rows = embeddings.shape[1] // num_columns + 1 249 | num_columns = 1 250 | num_rows = 1 251 | output_width = fw * num_columns 252 | output_height = fh * num_rows 253 | grid = gridplot(plots, ncols= num_columns, width=output_width, height=output_height) 254 | 255 | from bokeh.io import export_png, save, export_svg 256 | from pathlib import Path 257 | 258 | res_dir = Path(dataset_full_path).parent / cfg.proj_algorithm / name 259 | os.makedirs(str(res_dir), exist_ok=True) 260 | image_output_path = os.path.join(res_dir, "_".join([cfg.proj_algorithm ,embedding_key, *label_keys]) + ".png") 261 | svg_output_path = os.path.join(res_dir, "_".join([cfg.proj_algorithm, embedding_key, *label_keys]) + ".svg") 262 | html_save_path = os.path.join(res_dir, "_".join([cfg.proj_algorithm, embedding_key, *label_keys]) + ".html") 263 | 264 | print(f"Saving image to: {image_output_path}") 265 | 266 | save(grid, filename=html_save_path) 267 | export_svg(grid, filename=svg_output_path, width=output_width, height=output_height) 268 | export_png(grid, filename=image_output_path , width=output_width , height=output_height) 269 | 270 | print("Done and done ... ") 271 | 272 | if __name__ == "__main__": 273 | tsne() --------------------------------------------------------------------------------