├── config
│   └── config_vis
│       ├── config.yaml
│       └── vis
│           ├── default.yaml
│           └── resnet18.yaml
├── .gitignore
├── imgs
│   ├── tsne_resnet18_race.png
│   ├── tsne_resnet18_gender.png
│   └── latentplay_teaser.drawio.png
├── scripts
│   ├── paths.sh
│   ├── download_dlib68.sh
│   └── download_fairface.sh
├── env.yaml
├── gen_dataset.py
├── LICENSE
├── add_embedding.py
├── utils.py
├── README.md
├── loader_configs.py
└── visualization.py
/config/config_vis/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - vis: default
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | **/multirun
3 | **/models
4 | **/Experiments
5 | **/datasets
--------------------------------------------------------------------------------
/imgs/tsne_resnet18_race.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parsa-ra/LatentPlayInterface/HEAD/imgs/tsne_resnet18_race.png
--------------------------------------------------------------------------------
/imgs/tsne_resnet18_gender.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parsa-ra/LatentPlayInterface/HEAD/imgs/tsne_resnet18_gender.png
--------------------------------------------------------------------------------
/imgs/latentplay_teaser.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/parsa-ra/LatentPlayInterface/HEAD/imgs/latentplay_teaser.drawio.png
--------------------------------------------------------------------------------
/config/config_vis/vis/default.yaml:
--------------------------------------------------------------------------------
1 | name: 'test'
2 | dataset_full_path: "./datasets/"
3 | embedding_key: resnet18
4 | proj_algorithm: tsne
5 | discard: False
6 | max_unique_values: 10
7 | label_keys:
8 | - gender
--------------------------------------------------------------------------------
/scripts/paths.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LP_FAIRFACE_PATH=/path/to/dataset
4 | export LP_MORPH_PATH=/path/to/dataset
5 | export LP_RFW_PATH=/path/to/dataset
6 | export LP_UTKFACE_PATH=/path/to/dataset
7 | 
8 | export LP_DLIB_PREDICTOR=/path/to/dlib_predictor
9 |
--------------------------------------------------------------------------------
/config/config_vis/vis/resnet18.yaml:
--------------------------------------------------------------------------------
1 | name: 'test'
2 | dataset_full_path: "datasets/fairface2latentplay_resnet18.hf"
3 | embedding_key: resnet18
4 | proj_algorithm: tsne
5 | discard: False
6 | max_unique_values: 10
7 | label_keys:
8 | - gender
--------------------------------------------------------------------------------
/scripts/download_dlib68.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mkdir -p models && cd models
4 | wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
5 | bzip2 -d shape_predictor_68_face_landmarks.dat.bz2
6 | cd ..
7 | export LP_DLIB_PREDICTOR="$(pwd)/models/shape_predictor_68_face_landmarks.dat"
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: latentplay
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | - nvidia
6 | - pytorch3d
7 | - huggingface
8 | dependencies:
9 | - python=3.10
10 | - dlib
11 | - pytorch
12 | - torchvision
13 | - datasets
14 | - scipy
15 | - hydra-core
16 | - omegaconf
17 | - gdown
18 | - bokeh
19 | - Pillow
20 | - tqdm
21 | - pip:
22 | - python-json-logger
23 | - timm
24 |     - umap-learn
25 |     - scikit-learn   # imported by visualization.py (sklearn.manifold.TSNE)
26 |     - matplotlib     # imported by visualization.py for the color palette
--------------------------------------------------------------------------------
/scripts/download_fairface.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "Creating the datasets directory"
4 | mkdir -p datasets/fairface && cd datasets/fairface
5 | echo "Downloading the fairface dataset ... "
6 |
7 | gdown 1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86 # From the fairface repository
8 | gdown 1i1L3Yqwaio7YSOCj7ftgk8ZZchPG7dmH
9 | gdown 1wOdja-ezstMEp81tX1a-EYkFebev4h7D
10 |
11 | unzip fairface-img-margin025-trainval.zip
12 |
13 | export LP_FAIRFACE_PATH="$(pwd)"  # we are already inside datasets/fairface at this point
14 |
--------------------------------------------------------------------------------
/gen_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, logging, disable_caching
2 | from argparse import ArgumentParser
3 | from multiprocessing import cpu_count
4 | import os
5 |
6 | parser = ArgumentParser()
7 | parser.add_argument("dataset_name", choices=[
8 | "fairface",
9 | "morph",
10 | "rfw",
11 | "utkface",
12 | ])
13 | parser.add_argument("--output_path", default="datasets")
14 |
15 | args = parser.parse_args()
16 | os.makedirs(args.output_path, exist_ok=True)
17 | output_path = args.output_path
18 |
19 | logging.set_verbosity_info()
20 |
21 | dataset_builder_name = args.dataset_name + "2latentplay"
22 | data = load_dataset(f"./builders/{dataset_builder_name}", num_proc=cpu_count()-1)
23 |
24 | output_path = os.path.join(args.output_path, dataset_builder_name + ".hf")
25 | print(f"Dataset will be saved at: {output_path}")
26 | data.save_to_disk(output_path)
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Parsa Rahimi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/add_embedding.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from datasets import load_from_disk
4 | from io import BytesIO
5 | from PIL import Image
6 | from traceback import print_exc
7 | import torch
8 | from torchvision.models import resnet18, ResNet18_Weights
9 |
10 | from argparse import ArgumentParser
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('dataset_path')
14 |
15 | args = parser.parse_args()
16 |
17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
18 |
19 | dataset = load_from_disk(args.dataset_path)
20 |
21 | aligned_key = "image"
22 |
23 | arch = "resnet18"
24 | weights = ResNet18_Weights.DEFAULT
25 | preprocess = weights.transforms()
26 | model = resnet18(weights=weights)  # load the pretrained weights instead of a randomly initialized network
27 | model.fc = torch.nn.Identity()  # expose the 512-d penultimate features as the embedding
28 | model.eval()
29 | model.to(device)
30 | output_key = arch
31 | output_status_key = "status_" + arch
32 |
33 |
34 | def embedder(sample):
35 |     try:
36 |         # Load and preprocess the aligned image stored as raw bytes
37 |         img = preprocess(Image.open(BytesIO(sample[aligned_key])).convert('RGB'))
38 | 
39 |         # Add the batch dimension and move to the compute device
40 |         img = img.unsqueeze(0).to(device)
41 | 
42 |         with torch.no_grad():
43 |             feats = model(img)
44 | 
45 |         feats = feats.detach().to('cpu')
46 | 
47 |         sample[output_status_key] = True
48 |         sample[output_key] = feats[0, :]
49 | 
50 |         del img, feats
51 | 
52 |     except Exception:
53 |         print("Failed to embed sample")
54 |         print_exc()
55 | 
56 |         sample[output_status_key] = False
57 |         sample[output_key] = torch.zeros(512, dtype=torch.float32)
58 | 
59 |     return sample
60 | 
61 |
62 | dataset = dataset.map(embedder, writer_batch_size=1000)
63 |
64 | dataset.save_to_disk( args.dataset_path[:-3] + f"_{arch}.hf")
65 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import logging.config as config
3 | from logging.handlers import TimedRotatingFileHandler
4 | from pythonjsonlogger.jsonlogger import JsonFormatter
5 | import sys
6 | import os
7 |
8 | def get_env_key_val(name: str, default):
9 |     return os.getenv(name) if os.getenv(name) else default
10 | 
11 | stdout_loglvl = get_env_key_val("STDOUT_LOG_LEVEL", logging.WARNING)
12 | if isinstance(stdout_loglvl, str) and stdout_loglvl not in logging._nameToLevel:
13 |     stdout_loglvl = logging.DEBUG
14 | 
15 | print("Setting stdout loglvl to {}".format(stdout_loglvl))
16 | file_loglvl = get_env_key_val("FILE_LOG_LEVEL", logging.DEBUG)
17 | if isinstance(file_loglvl, str) and file_loglvl not in logging._nameToLevel:
18 |     file_loglvl = logging.DEBUG
19 | print("Setting file loglvl to {}".format(file_loglvl))
20 |
21 | base_directory = "logs"
22 | dirname = os.path.dirname(os.path.abspath(__file__))
23 | parent_folder_name = dirname.split('/')[-1]
24 | APP_NAME = parent_folder_name
25 | APP_VERSION = "0.1"
26 |
27 | def get_logger(name: str, add_stdout=True, add_file=True, default_log_level=logging.DEBUG):
28 | logger = logging.getLogger(name)
29 |     logger.setLevel(default_log_level)  # the handlers below apply their own, possibly stricter, levels
30 |
31 | fmt = logging.Formatter('%(name)s | %(levelname)s | %(asctime)s | %(filename)s | %(funcName)s | %(lineno)s | %(message)s')
32 | json_fmt = JsonFormatter('%(levelname)s %(message)s %(asctime)s %(name)s %(funcName)s %(lineno)d %(thread)d %(pathname)s', json_ensure_ascii=False)
33 |
34 |
35 | if add_stdout:
36 | stream_handler = logging.StreamHandler(sys.stdout)
37 | stream_handler.setLevel(stdout_loglvl)
38 | stream_handler.setFormatter(fmt)
39 | logger.addHandler(stream_handler)
40 |
41 |     if add_file:
42 |         os.makedirs(base_directory, exist_ok=True)  # the rotating handler fails if the log directory is missing
43 |         file_handler = TimedRotatingFileHandler(os.path.join(base_directory, f"logs_{name}.json"), when='d', interval=1, backupCount=60)
44 |         file_handler.setLevel(file_loglvl)
45 |         file_handler.setFormatter(json_fmt)
46 |         logger.addHandler(file_handler)
47 |
48 |     if not add_stdout and not add_file:
49 |         print("You should set at least one of the handlers to True")
50 |         sys.exit(1)
51 |
52 | return logger
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # LatentPlay Interface
3 |
4 | ![LatentPlay teaser](imgs/latentplay_teaser.drawio.png)
5 |
6 | Assume you have a dataset of images with multiple labels per image: for example, images of human faces labeled with the gender, race, and age of the person, or images of objects labeled with their names from a discrete set of possible values. Now suppose you want to compute an embedding of each image with different algorithms, or store post-processed versions of the images. The common practice is to keep each of these data types in separate files or folders (for example, an `.npy` file for the embeddings, a separate folder for the post-processed images, and the folder paths in a CSV file). Here we introduce a way to unify this storage and improve reusability with our proposed `LatentPlay` interface, which simplifies the process using the 🤗 [Datasets](https://huggingface.co/docs/datasets/en/index) library.
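To make the idea concrete, here is a toy sketch (the values are made up and only illustrate the layout; the real column schema lives in `loader_configs.py`) of how a LatentPlay-style dataset keeps image bytes, labels, and derived vectors side by side in a single 🤗 `Dataset`:

```python
# Toy sketch of the "everything in one Dataset" idea; the values below are fake.
from datasets import Dataset

toy = Dataset.from_dict({
    "person_id": ["0001", "0002"],
    "image": [b"<jpeg bytes>", b"<jpeg bytes>"],  # raw encoded images stored as binary blobs
    "gender": ["male", "female"],
    "race": ["asian", "white"],
    "resnet18": [[0.1] * 512, [0.2] * 512],       # an embedding vector per image, just another column
})

sample = toy[0]             # one row: image bytes, labels, and embedding together
toy.save_to_disk("toy.hf")  # the whole thing serializes to a single .hf folder
```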
7 |
8 |
9 | ## TLDR:
10 | A **practice** to handle multi-modal datasets in a unified way.
11 |
12 | This repository contains scripts to convert various datasets to our unified LatentPlay interface, based on the Hugging Face [`datasets`](https://huggingface.co/docs/datasets/en/index) library, to facilitate further experimentation with minimal code changes. You can also find some sample usages in the next sections.
13 |
14 | Currently it supports the following datasets:
15 | - [FairFace](https://github.com/joojs/fairface)
16 | - [MORPH](https://ieeexplore.ieee.org/document/1613043)
17 | - [RFW](http://www.whdeng.cn/RFW/index.html)
18 | - [UTKFace](https://susanqq.github.io/UTKFace/)
19 |
20 | Pull requests are welcome 🤝🏻.
21 |
22 | Given the path to a dataset in its originally published form, these scripts generate an `.hf` folder with the dataset packed into binary files.
23 |
24 | **Note**: The scripts are designed for the original distribution of the datasets (in terms of folder hierarchy) as published by the corresponding authors.
25 |
26 | Later you can simply load the dataset in your Python script like:
27 | ```python
28 | from datasets import load_from_disk
29 | dataset = load_from_disk('path/to/your/dataset.hf')
30 | ```
31 |
32 |
33 | ## Getting Started
34 | Create the conda environment:
35 | ```bash
36 | conda env create -f ./env.yaml   # or: mamba env create -f ./env.yaml / micromamba create -f ./env.yaml
37 | conda activate latentplay
38 | ```
39 |
40 |
41 | ## Making the LatentPlay version of FairFace [OPTIONAL]
42 |
43 | Downloading the [FairFace](https://drive.google.com/file/d/1Z1RqRo0_JiavaZw2yzZG6WETdZQ8qX86/view) dataset.
44 |
45 | We provide a helper script for that; just run:
46 | ```bash
47 | chmod +x ./scripts/download_fairface.sh
48 | ./scripts/download_fairface.sh
49 | ```
50 |
51 | I also provide a preprocessed version of this dataset; you can download it from the releases page.
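Once the images are downloaded, make sure `LP_FAIRFACE_PATH` points at them (see `scripts/paths.sh`), then pack them into the LatentPlay format with `python3 gen_dataset.py fairface`; this writes `datasets/fairface2latentplay.hf`, which is the file used in the samples below.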
52 |
53 |
54 | # Playing around with the dataset
55 | Now that you have a dataset and all of its metadata, you can play around with it, for example by adding a new column containing the embeddings of the images (yes, you can add vectors to a dataset that already contains images and metadata in different formats!). This makes it easy to manage our experiments.
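Mechanically this is just `Dataset.map`: any function mapped over the dataset can return extra keys, and they become new columns. A toy sketch (with a fake 4-d vector; Sample #1 below does the real thing with ResNet-18 features):

```python
# Toy sketch: attach a (fake) fixed vector to every sample as a new column.
dataset = dataset.map(lambda sample: {"my_embedding": [0.0, 1.0, 2.0, 3.0]})
```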
56 |
57 | ## Sample#1: Adding Embedding as a new column
58 | The following script extracts an embedding from the `image` column of our dataset and saves it in a new `resnet18` column.
59 | Optionally, you can download the dataset with pre-calculated embeddings from the releases page.
60 |
61 | ```bash
62 | python3 add_embedding.py datasets/fairface2latentplay.hf
63 | ```
64 |
65 | The new dataset with the new embedding column will be saved next to the original one.
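A quick sanity check of the result (a minimal sketch, assuming the FairFace dataset and the default column names written by `add_embedding.py`):

```python
from datasets import load_from_disk

# add_embedding.py writes the embedded dataset next to the original one
dataset = load_from_disk("datasets/fairface2latentplay_resnet18.hf")
split = dataset["train"]

print(split.column_names)             # ..., 'resnet18', 'status_resnet18'
print(len(split[0]["resnet18"]))      # length of the embedding vector (512 for these features)
print(sum(split["status_resnet18"]))  # number of samples that were embedded successfully
```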
66 |
67 |
68 | ## Sample#2: Visualization of the dataset using t-SNE
69 | Now that we have both the embeddings and their corresponding labels in a **unified dataset**, let's see how we can draw a `t-SNE` plot of the embeddings colored by different attribute groups (e.g. gender, race, ...).
70 |
71 | ```bash
72 | python3 visualization.py --multirun vis=resnet18
73 | ```
74 | The plots will be saved next to the dataset path.
75 |
76 | Or you can do the same for the `race` column in the dataset.
77 | ```bash
78 | python3 visualization.py --multirun vis=resnet18 vis.label_keys=['race']
79 | ```
80 |
81 | | Gender | Race |
82 | |--------------------------------------|--------------------------------------|
83 | | ![t-SNE of ResNet-18 embeddings colored by gender](imgs/tsne_resnet18_gender.png) | ![t-SNE of ResNet-18 embeddings colored by race](imgs/tsne_resnet18_race.png) |
84 |
85 |
86 | We can also draw the plots using `umap`; in this case you can run:
87 | ```bash
88 | python3 visualization.py --multirun vis=resnet18 vis.proj_algorithm=umap
89 | ```
90 |
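If you just want to see the core of what `visualization.py` does, without Hydra or Bokeh, here is a stripped-down sketch (assuming the `fairface2latentplay_resnet18.hf` dataset from Sample #1 and the `gender` label; the real script also caches the fitted projector and renders with Bokeh):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from datasets import load_from_disk

dataset = load_from_disk("datasets/fairface2latentplay_resnet18.hf")
split = dataset["train"]                       # t-SNE on a full split can be slow; subselect if needed
split.set_format("numpy", columns=["resnet18", "gender"])

embeddings = split["resnet18"]                 # (N, 512) array of embeddings
labels = split["gender"]
projected = TSNE(n_components=2, init="pca").fit_transform(embeddings)

for value in np.unique(labels):                # one scatter group per attribute value
    mask = labels == value
    plt.scatter(projected[mask, 0], projected[mask, 1], s=4, label=str(value))
plt.legend()
plt.savefig("tsne_resnet18_gender_sketch.png")
```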
91 | # TODO:
92 | Add more sample use cases ...
93 |
94 | # Disclaimer
95 | This repository is part of the code release for our full paper, originally published [here](https://gitlab.idiap.ch/biometric/sg_latent_modeling).
96 |
97 | # Citation
98 | If you find this framework useful, please consider citing our paper.
99 |
100 | ```bibtex
101 | @inproceedings{rahimi2023toward,
102 | title={Toward responsible face datasets: modeling the distribution of a disentangled latent space for sampling face images from demographic groups},
103 | author={Rahimi, Parsa and Ecabert, Christophe and Marcel, S{\'e}bastien},
104 | booktitle={2023 IEEE International Joint Conference on Biometrics (IJCB)},
105 | pages={1--11},
106 | year={2023},
107 | organization={IEEE}
108 | }
109 | ```
--------------------------------------------------------------------------------
/loader_configs.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import io
3 | import os
4 | from glob import glob
5 | from tqdm import tqdm
6 | import numpy as np
7 | import PIL
8 | from PIL import Image
9 | import scipy
10 | import dlib
11 |
12 | from datasets import Sequence
13 |
14 | preprocessing_target_size = 256
15 |
16 | # landmark detector
17 | dlib_landmark_detector_path = os.getenv("LP_DLIB_PREDICTOR")
18 |
19 | naming_splitter = "~"
20 |
21 | ######################## Specs ###########################
22 |
23 | male = "male"
24 | female = "female"
25 |
26 | race_map = {
27 | "asian": "asian",
28 | "black": "black",
29 | "african": "black",
30 | "caucasian": "white",
31 | "white": "white",
32 | "middleeastern": "middleeastern",
33 | "indian": "indian",
34 | "hispanic": "hispanic",
35 | "unknown": "unknown",
36 |
37 |     "O": "unknown",
38 | "A": "asian",
39 | "B": "black",
40 | "H": "hispanic",
41 | "W": "white",
42 |
43 | "East Asian": "asian", # TODO
44 | "Indian": "indian",
45 | "Black": "black",
46 | "Middle Eastern": "middleeastern",
47 | "White": "white",
48 | "Latino_Hispanic": "hispanic",
49 | "Southeast Asian": "asian", #TODO
50 |
51 | # From https://susanqq.github.io/UTKFace/
52 | # utkface
53 | "0": "white",
54 | "1": "black",
55 | "2": "asian",
56 | "3": "indian",
57 | "4": "unknown"
58 | }
59 |
60 | gender_map = {
61 | "m": "male",
62 | "f": "female",
63 | "M": "male",
64 | "male": "male",
65 | "man": "male",
66 | "F": "female",
67 | "women": "female",
68 | "female": "female",
69 | "unknown": "unknown",
70 |
71 | "Male": "male",
72 | "Female": "female",
73 |
74 | # From https://susanqq.github.io/UTKFace/
75 | # utkface
76 | "0": "male",
77 | "1": "female",
78 | }
79 |
80 |
81 | # TODO: add some kind of versioning of sort
82 |
83 | dataset_features = {
84 | "person_id": datasets.Value("string"), # Unique Id that will be used to identify if the persons are from the same identity or not ...
85 | "image": datasets.Value("binary"),
86 |
87 | "dlib_align_status": datasets.Value("bool"),
88 | "image_dlib_aligned": datasets.Value("binary"),
89 |
90 |
91 | "gender": datasets.Value("string"),
92 | "race": datasets.Value("string"),
93 | "age": datasets.Value("string"),
94 | "human": datasets.Value("bool"), # Used for human non-human stuff ...
95 | }
96 |
97 |
98 | def pre_process_images(raw_image_path, output_path, predictor):
99 | current_directory = os.getcwd()
100 | print(current_directory)
101 |
102 | aligned_images = []
103 | try:
104 | aligned_image = align_face(filepath=raw_image_path,
105 | predictor=predictor, output_size=preprocessing_target_size)
106 | aligned_images.append(aligned_image)
107 | except Exception as e:
108 | print(e)
109 |
110 | os.makedirs(output_path, exist_ok=True)
111 | images_names = [raw_image_path.split('/')[-1]]
112 | for image, name in zip(aligned_images, images_names):
113 | # Name without extensions
114 | real_name = name.split('.')[0]
115 | image.save(f'{output_path}/{real_name}.jpeg')
116 |
117 | os.chdir(current_directory)
118 |
119 | ## Borrowed from Insightfaces repository
120 | def get_landmark(filepath, predictor):
121 | """get landmark with dlib
122 | :return: np.array shape=(68, 2)
123 | """
124 | detector = dlib.get_frontal_face_detector()
125 |
126 | img = dlib.load_rgb_image(filepath)
127 | dets = detector(img, 1)
128 |
129 | for k, d in enumerate(dets):
130 | shape = predictor(img, d)
131 |
132 | t = list(shape.parts())
133 | a = []
134 | for tt in t:
135 | a.append([tt.x, tt.y])
136 | lm = np.array(a)
137 | return lm
138 |
139 |
140 | def align_face(filepath, predictor, output_size):
141 | """
142 | :param filepath: str
143 | :return: PIL Image
144 | """
145 |
146 | lm = get_landmark(filepath, predictor)
147 |
148 | lm_chin = lm[0: 17] # left-right
149 | lm_eyebrow_left = lm[17: 22] # left-right
150 | lm_eyebrow_right = lm[22: 27] # left-right
151 | lm_nose = lm[27: 31] # top-down
152 | lm_nostrils = lm[31: 36] # top-down
153 | lm_eye_left = lm[36: 42] # left-clockwise
154 | lm_eye_right = lm[42: 48] # left-clockwise
155 | lm_mouth_outer = lm[48: 60] # left-clockwise
156 | lm_mouth_inner = lm[60: 68] # left-clockwise
157 |
158 | # Calculate auxiliary vectors.
159 | eye_left = np.mean(lm_eye_left, axis=0)
160 | eye_right = np.mean(lm_eye_right, axis=0)
161 | eye_avg = (eye_left + eye_right) * 0.5
162 | eye_to_eye = eye_right - eye_left
163 | mouth_left = lm_mouth_outer[0]
164 | mouth_right = lm_mouth_outer[6]
165 | mouth_avg = (mouth_left + mouth_right) * 0.5
166 | eye_to_mouth = mouth_avg - eye_avg
167 |
168 | # Choose oriented crop rectangle.
169 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
170 | x /= np.hypot(*x)
171 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
172 | y = np.flipud(x) * [-1, 1]
173 | c = eye_avg + eye_to_mouth * 0.1
174 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
175 | qsize = np.hypot(*x) * 2
176 |
177 | # read image
178 | img = PIL.Image.open(filepath)
179 |
180 | transform_size = output_size
181 | enable_padding = True
182 |
183 | # Shrink.
184 | shrink = int(np.floor(qsize / output_size * 0.5))
185 | if shrink > 1:
186 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
187 |         img = img.resize(rsize, PIL.Image.LANCZOS)  # Pillow 10 removed ANTIALIAS; LANCZOS is the same filter
188 | quad /= shrink
189 | qsize /= shrink
190 |
191 | # Crop.
192 | border = max(int(np.rint(qsize * 0.1)), 3)
193 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
194 | int(np.ceil(max(quad[:, 1]))))
195 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
196 | min(crop[3] + border, img.size[1]))
197 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
198 | img = img.crop(crop)
199 | quad -= crop[0:2]
200 |
201 | # Pad.
202 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
203 | int(np.ceil(max(quad[:, 1]))))
204 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
205 | max(pad[3] - img.size[1] + border, 0))
206 | if enable_padding and max(pad) > border - 4:
207 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
208 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
209 | h, w, _ = img.shape
210 | y, x, _ = np.ogrid[:h, :w, :1]
211 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
212 | 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
213 | blur = qsize * 0.02
214 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
215 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
216 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
217 | quad += pad[:2]
218 |
219 | # Transform.
220 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BICUBIC)
221 | if output_size < transform_size:
222 |         img = img.resize((output_size, output_size), PIL.Image.LANCZOS)  # Pillow 10 removed ANTIALIAS
223 |
224 | # Return aligned image.
225 | return img
--------------------------------------------------------------------------------
/visualization.py:
--------------------------------------------------------------------------------
1 | from enum import unique
2 | from utils import get_logger
3 | from sklearn.manifold import TSNE
4 | from bokeh.plotting import figure, show
5 | from bokeh.transform import factor_cmap, factor_mark
6 | from bokeh.layouts import gridplot
7 | from bokeh.models import ColumnDataSource
8 | from typing import List
9 | from copy import copy, deepcopy
10 | from joblib import dump, load
11 | from glob import glob
12 | import numpy as np
13 | import pandas as pd
14 | from datasets import load_from_disk, Dataset
15 | from copy import copy
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 | import matplotlib
19 | import hydra
20 | from omegaconf import DictConfig, OmegaConf
21 | from multiprocessing import cpu_count
22 |
23 | import os
24 |
25 | logger = get_logger("Manifold", True, False)
26 |
27 | experiments_base_path = "./Experiments"
28 | projection_base_path = os.path.join(experiments_base_path, 'projection')
29 |
30 | fw, fh = 1200, 1200
31 |
32 | np.random.seed(13)
33 |
34 | def generate_distinct_colors(n):
35 | colors = plt.cm.get_cmap("hsv", n)
36 | return [matplotlib.colors.rgb2hex(colors(i)[:3]) for i in range(n)]
37 |
38 |
39 | MARKERS = ['hex', 'circle_x', 'triangle', 'square', 'circle', 'dot', 'asterisk', 'diamond']
40 | COLORS = generate_distinct_colors(20)
41 | np.random.shuffle(COLORS)
42 | print(COLORS)
43 |
44 | change_label_name = False
45 | embeddings = None
46 | metas = None
47 | FORCE_RERUN = False
48 |
49 | print("Creating projection base path ... ")
50 | os.makedirs(projection_base_path, exist_ok=True)
51 |
52 | colormap = dict()
53 | markermap = dict()
54 |
55 | @hydra.main(version_base=None, config_path="config/config_vis", config_name="config.yaml")
56 | def tsne(cfg: DictConfig):
57 | cfg = cfg.vis
58 |
59 | name = cfg.name
60 |
61 | dataset_full_path = cfg.dataset_full_path
62 | print(dataset_full_path)
63 |
64 | embedding_key = cfg.embedding_key
65 | label_keys = cfg.label_keys
66 |
67 |     assert len(label_keys) <= 2, "Right now, for t-SNE we can only show two labels at a time"
68 | label_map = dict()
69 |
70 | label_precedence = ["colormap", "markermap"]
71 | value_precedence = [ "COLORS", "MARKERS" ]
72 | values = ["color", "marker"]
73 |
74 | dataset: Dataset = load_from_disk(dataset_full_path)
75 | dataset_keys = list()
76 | dataset_keys.append(embedding_key)
77 | dataset_keys.extend(label_keys)
78 |
79 | #prefix_name= dataset['train']._fingerprint # Inferfrom name or fingerprint
80 | prefix_name = f"{cfg.proj_algorithm.upper()}s"
81 | exp_name = "-".join([prefix_name, embedding_key, *label_keys, dataset['train']._fingerprint])
82 |
83 | status_prefix = "status"
84 | for key in dataset:
85 | status_key_name = status_prefix + "_"+embedding_key
86 | dataset_check = dataset[key].select_columns(column_names=[embedding_key, status_key_name])
87 | dataset_check = dataset_check.to_pandas()
88 | print(dataset_check[status_key_name].value_counts().to_dict())
89 |
90 |
91 | if cfg.discard:
92 |         print(f"Filtering entries with their {cfg.discard + embedding_key} set to true ...")
93 | dataset = dataset.filter(lambda entry: entry[ cfg.discard + embedding_key],
94 | num_proc=min(1,cpu_count()//2)) # There is some problem with the datasets' filtering process with multiple processes
95 | print("Filtering finished.")
96 |
97 | dataset.set_format("numpy", columns=dataset_keys)
98 |
99 |     if 'test' not in dataset and 'train' in dataset:
100 |         logger.warning('Test shard is not in the dataset, will split the dataset automatically')
101 |         dataset = dataset['train'].train_test_split(test_size=0.1)
102 |     elif 'train' not in dataset or 'test' not in dataset:
103 |         logger.error('Unknown split names, the dataset should contain `train` and `test` splits')
104 |
105 | train_embeddings = dataset['train'][embedding_key]
106 | test_embeddings = dataset['test'][embedding_key]
107 |
108 | test_labels = list()
109 |
110 | for label_key in label_keys:
111 | test_labels.append(dataset['test'][label_key])
112 |
113 | unique_keys: List[List] = list()
114 |
115 | for label_idx, label_key in enumerate(label_keys):
116 | unique_keys.append(list())
117 |
118 | cur_labels = test_labels[label_idx]
119 | for cur_label in cur_labels:
120 | if cur_label not in unique_keys[-1]:
121 | print(f"Adding {cur_label} for {label_key}")
122 | unique_keys[-1].append(cur_label)
123 |
124 |
125 | # import sys
126 | print("\n\n\n", unique_keys)
127 | max_unique_key_len = 0
128 | for unique_key in unique_keys:
129 | max_unique_key_len = max(len(unique_key), max_unique_key_len)
130 | print(max_unique_key_len, "\n\n\n")
131 | # sys.exit(0)
132 |
133 | plots = list()
134 |
135 | title="-".join([prefix_name, embedding_key, *label_keys])
136 | segmented_title = ""
137 | cur_pos = 0
138 | pos_step = 25
139 | while cur_pos < len(title):
140 | segmented_title += title[cur_pos: cur_pos+pos_step] + "\n"
141 | cur_pos += pos_step
142 |
143 |
144 | print(f"Original Title was: {title}\nSegmented title is: {segmented_title}\n")
145 | p = figure(title = segmented_title, frame_width=fw, frame_height=fh, background_fill_color="#fafafa")
146 |
147 | cur_embeddings = train_embeddings
148 | projector_name = exp_name
149 | projector_path = os.path.join(projection_base_path ,f"{projector_name}_manifoldlearning.pickle")
150 | projector_sklearn_api = None
151 | if os.path.exists(projector_path) and not FORCE_RERUN:
152 | logger.info(f"It seems that the projector file for current config already exists in {projector_path}, ... Loading it from disk instead of re-running the tsne")
153 | with open(projector_path, 'rb') as file:
154 | projector_sklearn_api = load(file)
155 | else:
156 | logger.info(f"The tsne file on disk cannot be found at {projector_path}, running the tsne training ... ")
157 |
158 | if cfg.proj_algorithm == "tsne":
159 | logger.info(f"Setting projection algorithm to TSNE")
160 | projector_sklearn_api = TSNE(
161 | n_components=2,
162 | learning_rate="auto",
163 | perplexity=50,
164 | n_iter=1000,
165 | n_jobs=-1,
166 | init="pca"
167 | )
168 | elif cfg.proj_algorithm == "umap":
169 | logger.info(f"Setting projection algorithm to UMAP")
170 | import umap
171 | projector_sklearn_api = umap.UMAP()
172 | else:
173 | logger.error(f"Unsupported projection algorithm {cfg.proj_algorithm}")
174 | raise ValueError()
175 |
176 | projector_sklearn_api.fit_transform(train_embeddings)
177 |
178 | with open(projector_path, 'wb') as file:
179 | dump(projector_sklearn_api, file)
180 |
181 | colormap = dict()
182 | markermap = dict()
183 |
184 | for label_idx, label_key in enumerate(label_keys):
185 | print(label_key, label_idx)
186 | for idx, entry in enumerate(unique_keys[label_idx]):
187 | print(entry, idx)
188 | print(label_precedence[label_idx], value_precedence[label_idx])
189 |
190 | #print(globals()[label_precedence[label_idx]])
191 | #print(globals()[value_precedence[label_idx]])
192 |
193 | globals()[label_precedence[label_idx]][entry] = globals()[value_precedence[label_idx]][idx]
194 |
195 | colormap[entry] = COLORS[idx]
196 |
197 | # colormap[entry] = COLORS[idx]
198 | # markermap[entry] = MARKERS[idx]
199 |
200 | #print("\n\n\nColorMap", colormap)
201 |
202 | # if change_label_name:
203 | # mapping_label_name = dict()
204 | # for key in unique_keys:
205 | # alternative_name = str(input(f"Enter name alternative name for the {key}: "))
206 | # mapping_label_name[key] = alternative_name
207 | # for idx, meta in enumerate(metas):
208 | # metas[idx] = mapping_label_name[meta]
209 |
210 | projected = projector_sklearn_api.fit_transform(test_embeddings)
211 |
212 | #print(list(colormap.keys()))
213 | source = ColumnDataSource(dict(
214 | x=projected[:,0],
215 | y=projected[:,1],
216 | # marker = [ markermap[meta] for meta in test_labels[1] ],
217 | color = [ colormap[meta] for meta in test_labels[0] ],
218 | label = test_labels[0]
219 | )
220 | )
221 |
222 | print("Plotting scatter plot ... ")
223 | p.scatter(
224 | x = 'x',
225 | y = 'y',
226 | source = source,
227 | size = 14,
228 | color ='color',
229 | #marker ='marker',
230 | legend_group ='label',
231 | fill_alpha =0.5,
232 | )
233 | p.legend.location = "top_left"
234 | p.legend.title = "-".join([prefix_name, embedding_key, *label_keys])
235 | print("plotting scatter done")
236 |
237 | p.legend.background_fill_alpha = 0.3
238 | p.legend.title_text_font_size = '35pt'
239 | p.legend.label_text_font_size = '38pt'
240 | p.title.text_font_size = "40pt"
241 |
242 | plots.append(copy(p))
243 |
244 | # from joblib import dump
245 | # with open('./plots_tsne_vaule_wp', 'wb') as file:
246 | # dump(plots, file)
247 | # num_columns = max(0, int(embeddings.shape[1] ** (1/2)))
248 | # num_rows = embeddings.shape[1] // num_columns + 1
249 | num_columns = 1
250 | num_rows = 1
251 | output_width = fw * num_columns
252 | output_height = fh * num_rows
253 | grid = gridplot(plots, ncols= num_columns, width=output_width, height=output_height)
254 |
255 | from bokeh.io import export_png, save, export_svg
256 | from pathlib import Path
257 |
258 | res_dir = Path(dataset_full_path).parent / cfg.proj_algorithm / name
259 | os.makedirs(str(res_dir), exist_ok=True)
260 | image_output_path = os.path.join(res_dir, "_".join([cfg.proj_algorithm ,embedding_key, *label_keys]) + ".png")
261 | svg_output_path = os.path.join(res_dir, "_".join([cfg.proj_algorithm, embedding_key, *label_keys]) + ".svg")
262 | html_save_path = os.path.join(res_dir, "_".join([cfg.proj_algorithm, embedding_key, *label_keys]) + ".html")
263 |
264 | print(f"Saving image to: {image_output_path}")
265 |
266 | save(grid, filename=html_save_path)
267 | export_svg(grid, filename=svg_output_path, width=output_width, height=output_height)
268 | export_png(grid, filename=image_output_path , width=output_width , height=output_height)
269 |
270 | print("Done and done ... ")
271 |
272 | if __name__ == "__main__":
273 | tsne()
--------------------------------------------------------------------------------