├── .gitignore ├── README.md ├── SVM_TODO.py ├── _0_prep_dataset.py ├── _1_embed_with_CLIP.py ├── _2_remove_duplicates.py ├── _3_label_images.py ├── _4_train_model.py ├── _5_predict_labels.py ├── _6_create_subset.py ├── investigate_embedding.py ├── models └── single_crop_regression_9.4k_imgs_80_epochs.pth ├── predict_simple.py ├── tools ├── find_similar_imgs.py ├── fix_img_dir.py └── move_subset_of_files.py └── utils ├── embedder.py ├── image_features.py ├── merge_datasets.py ├── nn_model.py └── train_latent_regressor.py /.gitignore: -------------------------------------------------------------------------------- 1 | #models/* 2 | CLIP_cmds.txt 3 | cmds.txt 4 | losses.png 5 | __pycache__/ 6 | utils/models/* 7 | utils/*.png 8 | test_set_predictions.png 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLIP assisted data labeling 2 | Python toolkit to quickly label/filter lots of images using CLIP embeddings + active learning. 3 | Main use-case is to filter large image datasets that contain lots of bad images you don't want to train on. 4 | This kit is meant to be used on a localhost linux desktop with display (labeling is done with simple OpenCV display). 5 | 6 | ## Overview: 7 | 0. (optional, but recommended) Create unique uuid's for each img in the root_dir 8 | 1. Embed all the images in the database using CLIP 9 | 2. (optional) Remove potential duplicate images based on a cosine-similarity threshold in CLIP-space 10 | 3. Manually label a few images (10 minutes of labeling is usually sufficient to start) 11 | Labeling supports ordering the images in several different ways: 12 | - by uuid (= random) 13 | - best predicted first 14 | - worst predicted first 15 | - median predicted first (start labeling where there's most uncertainty) 16 | - diversity sort (tries to start with an as diverse as possible subset of the data) 17 | 4. Train a NN regressor/classifier on the current database (CLIP-embedding --> label) 18 | 5. Predict the labels for all the unlabeled images 19 | --> Go back to (3) and iterate until satisfied with the predicted labels 20 | 6. Filter your dataset using the predicted labels 21 | 22 | ## Example usage: 23 | 24 | ``` 25 | export ROOT_DIR="path_to_your_img_dir" 26 | python _0_prep_dataset.py --root_dir $ROOT_DIR --mode rename 27 | python _1_embed_with_CLIP.py --root_dir $ROOT_DIR 28 | ``` 29 | Optional Step: 30 | ``` 31 | python _2_remove_duplicates.py --root_dir $ROOT_DIR --mode move 32 | ``` 33 | ``` 34 | python _3_label_images.py --root_dir $ROOT_DIR 35 | python _4_train_model.py --train_data_dir path_to_labeled_root_dir --train_data_names labeled_subfolder_name_01 labeled_subfolder_name_02 --model_name model_01 --test_fraction 0.20 36 | python _5_predict_labels.py --root_dir $ROOT_DIR --model_file name_of_trained_model_01 --copy_imgs_fraction 1.0 --batch_size 6 37 | ``` 38 | Finally, apply the trained model to a new dataset: 39 | ``` 40 | python _1_embed_with_CLIP.py --root_dir path_to_large_unlabeled_img_dir 41 | python _6_create_subset.py --input_dir path_to_large_unlabeled_img_dir --min_score 0.4 --extensions .jpg .json 42 | ``` 43 | 44 | ## Detailed walkthrough: 45 | 46 | ### 0. 
Preprocessing your data 47 | It is recommended to use the --convert_imgs_to_jpg flag to auto-convert all your images to .jpg 48 | Huge images will also be auto-resized (control the maximum resolution with the --max_n_pixels flag) 49 | Metadata files (such as .txt prompt files or .json files) that have the same basename (but a different extension) as the image files can remain and will be handled correctly. 50 | 51 | In all following scripts, the root_dir is the main directory where your training images live. 52 | Most scripts should also work if this root_dir has subfolders with e.g. different subsets of the data. 53 | 54 | ### _0_prep_dataset.py 55 | The labels are kept in a .csv file with a unique identifier linking to each image. 56 | To avoid name clashes, it is highly recommended to rename each img (and any metadata files with the same name) with a unique uuid 57 | 58 | ### _1_embed_with_CLIP.py 59 | Specify a CLIP model name and a batch size, then embed all images using CLIP (embeddings are stored on disk as .pt files) 60 | For each image, 4 crops are taken: 61 | - square crop at the centre of the image 62 | - image padded to a full square 63 | - subcrop1 (to detect blurry images and zoomed-in details) 64 | - subcrop2 (to detect blurry images and zoomed-in details) 65 | 66 | Additionally, some manually engineered img features are also computed and saved to disk. 67 | 68 | ### _2_remove_duplicates.py 69 | Specify a cosine-similarity threshold and remove duplicate images from the dataset. 70 | This currently only works on max ~10k imgs at a time (due to the quadratic memory requirement of the all-to-all distance matrix), 71 | but the script randomly shuffles all imgs, so running it a few times should catch most of the duplicates! 72 | 73 | ### _3_label_images.py 74 | This script currently only works on a single image folder with no subfolders! 75 | Super basic labeling interface using OpenCV that supports re-ordering the images based on predicted labels. 76 | Label an image using the number keys [0-9] on the keyboard. 77 | Go forward and backward using the arrow keys <-- / --> 78 | If a .txt file with the same basename as the image is found, the text in it will be displayed as the prompt for the img. 79 | 80 | ### _4_train_model.py 81 | Train a simple 3-layer FC neural network with ReLUs on the flattened CLIP-crop embeddings to regress / classify the image labels. 82 | Flow: 83 | - first optimize hyperparameters using e.g. `--test_fraction 0.15 --n_epochs 100` and `--dont_save` 84 | - look at the train/test loss curves to figure out the best number of epochs to train 85 | - finally, do a final training run on all the data using `--test_fraction 0.0` 86 | 87 | ### _5_predict_labels.py 88 | Predict the labels for the entire image dataset using the trained model: `--model_file name_of_your_model` 89 | `--copy_named_imgs_fraction` can be used to get a sneak peek of labeled results in a tmp output_directory 90 | 91 | ### _6_create_subset.py 92 | Finally, use the predicted labels to copy a subset of the dataset to an output folder. 93 | (Currently does not yet work on folders with subdirs) 94 | 95 | ## TODO 96 | - add requirements.txt 97 | - add a keyboard mapping class to the labeling code that has different mappings depending on which OS is running the code (currently, the keys are hardcoded for Ubuntu 20.04) 98 | - CLIP features are great for semantic labeling/filtering, but tend to ignore low-level details like texture sharpness, pixel grain and blurriness. 99 | The pipeline can probably be improved by adding additional features (lpips, vgg, ...) 100 | - Currently the scoring model is just a heavily regularized 3-layer FC neural network. It's likely that adding a more linear component (e.g. an SVM) could make the predictions more robust (see the sketch below). 101 | - The labeling tool currently only supports numerical labels and the pipeline is built for regression. This could be easily extended to class labels + classification. 102 | 103 | 
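A minimal, untested sketch of such a linear baseline, fitting an SVR directly on the stored CLIP embeddings. The paths, the CLIP model key and the crop name below are placeholders, and the .pt feature dicts are assumed to be in the format written by _1_embed_with_CLIP.py:

```
import os
import numpy as np
import pandas as pd
import torch
from sklearn.svm import LinearSVR
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score

data_dir   = "path_to_labeled_root_dir/labeled_subfolder_name_01"      # folder with uuid.jpg / uuid.pt pairs
label_csv  = "path_to_labeled_root_dir/labeled_subfolder_name_01.csv"  # written by _3_label_images.py
clip_model = "ViT-L-14-336/openai"       # key the features were saved under
crop_name  = "square_padded_crop"        # one of the stored crops

# Collect (flattened CLIP embedding, label) pairs for all labeled images:
df = pd.read_csv(label_csv).dropna(subset=["label"])
X, y = [], []
for _, row in df.iterrows():
    feature_path = os.path.join(data_dir, f"{row['uuid']}.pt")
    if not os.path.exists(feature_path):
        continue
    feature_dict = torch.load(feature_path, map_location="cpu")
    X.append(feature_dict[clip_model][crop_name].flatten().numpy())
    y.append(row["label"])

X, y = np.stack(X), np.array(y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit a linear support-vector regressor as a baseline next to the FC network:
svr = LinearSVR(C=0.1, max_iter=10000).fit(X_train, y_train)
print("held-out r2:", r2_score(y_test, svr.predict(X_test)))
```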
-------------------------------------------------------------------------------- /SVM_TODO.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.random.seed(42) 3 | 4 | """ 5 | 6 | from Karpathy: 7 | https://twitter.com/karpathy/status/1647025230546886658 8 | 9 | Q: Can this approach for finding "similar" embeddings also be transformed 10 | to make a better classifier / regressor in high dimensional spaces? 11 | 12 | """ 13 | 14 | dim = 768 15 | n = 1000 16 | 17 | embeddings = np.random.randn(n, dim) # n documents, dim-dimensional embeddings 18 | embeddings = embeddings / np.sqrt((embeddings**2).sum(1, keepdims=True)) # L2 normalize the rows, as is common 19 | 20 | query = np.random.randn(dim) # the query vector 21 | query = query / np.sqrt((query**2).sum()) 22 | 23 | # Tired: use kNN 24 | similarities = embeddings.dot(query) 25 | sorted_ix = np.argsort(-similarities) 26 | print("top 10 results:") 27 | for k in sorted_ix[:10]: 28 | print(f"row {k}, similarity {similarities[k]}") 29 | 30 | 31 | # Wired: use an SVM 32 | from sklearn import svm 33 | 34 | # create the "Dataset" 35 | x = np.concatenate([query[None,...], embeddings]) # x is a (1001, 768) array, i.e. (n+1, dim), with the query as the first row 36 | y = np.zeros(n+1) 37 | y[0] = 1 # we have a single positive example, mark it as such 38 | 39 | # train SVM 40 | # docs: https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html 41 | clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1) 42 | clf.fit(x, y) # train 43 | 44 | # infer on whatever data you wish, e.g. the original data 45 | similarities = clf.decision_function(x) 46 | sorted_ix = np.argsort(-similarities) 47 | print("\nSVM:") 48 | print("top 10 results:") 49 | for k in sorted_ix[:10]: 50 | print(f"row {k}, similarity {similarities[k]}") -------------------------------------------------------------------------------- /_0_prep_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | import uuid 5 | import argparse 6 | from tqdm import tqdm 7 | from PIL import Image 8 | 9 | all_img_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp', '.JPEG', '.JPG', '.PNG', '.BMP', '.TIFF', '.TIF', '.WEBP'] 10 | 11 | def process_file(orig_path, new_path, args): 12 | """ 13 | Given an orig_path and new_path: 14 | 1. soft-load with PIL to check if the resolution is within bounds 15 | 2. Optionally downsize the image 16 | 3. Convert to jpg if necessary 17 | 4. 
Rename or copy the file to the new_path 18 | """ 19 | 20 | os.makedirs(os.path.dirname(new_path), exist_ok=True) 21 | file_extension = os.path.splitext(orig_path)[1] 22 | 23 | is_image = file_extension in all_img_extensions 24 | converted, resized = 0, 0 25 | 26 | if is_image: 27 | img = Image.open(orig_path) 28 | width, height = img.size 29 | if (width * height) > args.max_n_pixels: 30 | scale = (args.max_n_pixels / (width * height)) ** 0.5 # uniform downscale factor so the total pixel count lands at max_n_pixels 31 | new_width, new_height = int(width * scale), int(height * scale) 32 | img = img.resize((new_width, new_height), Image.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10 33 | if args.convert_imgs_to_jpg: 34 | new_path = os.path.splitext(new_path)[0] + '.jpg' 35 | img.save(new_path, quality=95) 36 | resized = 1 37 | if args.mode == 'rename': os.remove(orig_path) # don't leave the original behind when renaming in place 38 | if is_image and args.convert_imgs_to_jpg and not resized: 39 | if file_extension != '.jpg': 40 | new_path = os.path.splitext(new_path)[0] + '.jpg' 41 | img = Image.open(orig_path).convert("RGB") 42 | img.save(new_path, quality=95) 43 | if args.mode == 'rename': os.remove(orig_path) # keep the original when in copy mode 44 | converted = 1 45 | 46 | if not is_image or (not resized and not converted): 47 | if args.mode == 'rename': 48 | os.rename(orig_path, new_path) 49 | elif args.mode == 'copy': 50 | shutil.copy(orig_path, new_path) 51 | 52 | return converted, resized 53 | 54 | from natsort import natsorted, ns 55 | def nautilus_sort(filenames): 56 | # Sort filenames naturally and case-insensitively 57 | return natsorted(filenames, alg=ns.IGNORECASE) 58 | 59 | 60 | def prep_dataset_directory(args): 61 | 62 | ''' 63 | Rename all the files in the root_dir with a unique string identifier 64 | Optionally: 65 | - convert imgs to jpg 66 | - downsize imgs if needed 67 | 68 | ''' 69 | 70 | os.makedirs(args.output_dir, exist_ok=True) 71 | renamed_counter, converted_counter, resized_counter, skipped = 0, 0, 0, 0 72 | print_verb = "Copied" if args.mode == 'copy' else "Renamed" 73 | 74 | for subdir, dirs, files in os.walk(args.root_dir): 75 | print(f"Parsing {subdir}, subdirs: {dirs}, n_files: {len(files)}..") 76 | 77 | # Walk through this directory in alphabetical order: 78 | files = nautilus_sort(files) 79 | 80 | # Get all the unique filenames (without the extension) and store a list of present extensions for each one: 81 | unique_filenames = {} 82 | for file in files: 83 | filename, file_extension = os.path.splitext(file) 84 | if filename not in unique_filenames: 85 | unique_filenames[filename] = [] 86 | unique_filenames[filename].append(file_extension) 87 | 88 | # create sorted, but random uuids: 89 | uuids = nautilus_sort([str(uuid.uuid4().hex) for _ in range(len(unique_filenames.keys()))]) 90 | 91 | if args.shuffle_file_order: 92 | random.shuffle(uuids) # shuffle in place; random.shuffle() returns None, so don't re-assign 93 | 94 | for i, filename in tqdm(enumerate(unique_filenames.keys())): 95 | extension_list = unique_filenames[filename] 96 | 97 | for ext in extension_list: 98 | new_folder = subdir.replace(args.root_dir, args.output_dir) 99 | orig_filename = os.path.join(subdir, filename + ext) 100 | new_filename = os.path.join(new_folder, uuids[i] + ext) 101 | 102 | try: 103 | converted, resized = process_file(orig_filename, new_filename, args) 104 | renamed_counter += 1 105 | converted_counter += converted 106 | resized_counter += resized 107 | except Exception as e: 108 | print(f"Error on {orig_filename}: {e}") 109 | skipped += 1 110 | continue 111 | 112 | print(f"{print_verb} {renamed_counter} files (converted {converted_counter}, resized {resized_counter}), skipped {skipped}") 113 | 114 | if __name__ == "__main__": 115 | """ 116 | This script renames all the files in the root_dir with a unique string 
identifier, 117 | it also optionally converts all images to jpg and downsizes them if they are very large. 118 | """ 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--root_dir', type=str, help='Root directory of the dataset folder') 121 | parser.add_argument('--output_dir', type=str, default = None, help='Output directory') 122 | parser.add_argument('--mode', type=str, default='copy', help='Modes: rename (in place) or copy') 123 | parser.add_argument('--max_n_pixels', type=int, default=2048*2048, help='Resize when an img is larger than this') 124 | parser.add_argument('--convert_imgs_to_jpg', action='store_true', help='Convert all imgs to .jpg (default: False)') 125 | parser.add_argument('--shuffle_file_order', action='store_true', help='Randomly shuffle the alphabetical ordering of imgs (default: False)') 126 | args = parser.parse_args() 127 | 128 | if args.mode == 'copy' and args.output_dir is None: 129 | raise ValueError("Output directory must be specified when mode is 'copy'") 130 | 131 | if args.output_dir is None: 132 | args.output_dir = args.root_dir 133 | args.mode = 'rename' 134 | 135 | if args.mode == 'rename': 136 | print("####### WARNING #######") 137 | print(f"you are about to rename / resize all the files inside {args.root_dir}, are you sure you want to do this?") 138 | answer = input("Type 'yes' to continue: ") 139 | if answer != 'yes': 140 | raise ValueError("Aborted") 141 | 142 | prep_dataset_directory(args) -------------------------------------------------------------------------------- /_1_embed_with_CLIP.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | # Keep open_clip import for potential type hints or listing models, though not directly used for encoding now 3 | import open_clip 4 | import torch, os, time 5 | from tqdm import tqdm 6 | import random 7 | import argparse 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import torch 12 | from torch.utils.data import Dataset, DataLoader 13 | import torch.multiprocessing as mp 14 | 15 | from utils.embedder import CustomImageDataset, CLIP_Encoder, PE_Encoder 16 | 17 | _DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | 19 | if 0: 20 | # Consider listing PE models too if available 21 | print("Pretrained open_clip models available:") 22 | try: 23 | options = open_clip.list_pretrained() 24 | for option in options: 25 | print(option) 26 | except Exception as e: 27 | print(f"Could not list open_clip models: {e}") 28 | # Add PE model listing here if possible 29 | # Example: print(pe.CLIP.available_configs()) # Need to import pe 30 | print("-----------------------------") 31 | 32 | 33 | # Rename CLIP_Feature_Dataset to Feature_Dataset 34 | class Feature_Dataset(): 35 | # Updated __init__ to handle different model types 36 | def __init__(self, root_dir, model_name, batch_size, 37 | model_path = None, 38 | force_reencode = False, 39 | shuffle_filenames = True, 40 | num_workers = 0, 41 | crop_names = ["centre_crop", "square_padded_crop", "subcrop1", "subcrop2"]): 42 | 43 | self.device = _DEVICE 44 | self.root_dir = root_dir 45 | self.model_name = model_name # Store the model name 46 | self.force_reencode = force_reencode 47 | self.img_extensions = (".png", ".jpg", ".jpeg", ".JPEG", ".JPG", ".PNG") 48 | self.batch_size = batch_size 49 | self.crop_names = crop_names 50 | 51 | # Find all images in root_dir: 52 | print("Searching images..") 53 | self.img_filepaths = [] 54 | for root, dirs, files in os.walk(root_dir): 55 | for name in 
files: 56 | if name.endswith(self.img_extensions): 57 | new_filename = os.path.join(root, name) 58 | self.img_filepaths.append(new_filename) 59 | 60 | if shuffle_filenames: 61 | random.shuffle(self.img_filepaths) 62 | else: # sort filenames: 63 | self.img_filepaths.sort() 64 | 65 | print(f"---> Found {len(self.img_filepaths)} images in {root_dir}") 66 | 67 | # Instantiate the correct encoder based on model name convention 68 | # Assuming PE models start with 'PE-' based on test.py 69 | if model_name.startswith("PE-"): 70 | self.encoder = PE_Encoder(model_name, device=self.device) 71 | # PE model path isn't handled by PE_Encoder yet, assuming download from HF 72 | elif '/' in model_name: # Assume open_clip format like 'ViT-L-14-336/openai' 73 | self.encoder = CLIP_Encoder(model_name, model_path, device=self.device) 74 | else: 75 | raise ValueError(f"Unknown model format: {model_name}. Expected 'PE-...' or 'Arch/Dataset'.") 76 | 77 | # Get the preprocessing transform from the encoder 78 | preprocess_transform = self.encoder.get_preprocess_transform() 79 | 80 | # Pass the transform to CustomImageDataset 81 | self.img_dataset = CustomImageDataset(self.img_filepaths, self.crop_names, preprocess_transform) 82 | dataloader_kwargs = { 83 | 'batch_size': batch_size, 84 | 'shuffle': False, 85 | 'num_workers': num_workers 86 | } 87 | if num_workers > 0: 88 | dataloader_kwargs['prefetch_factor'] = 2 # Default value, adjust if needed 89 | 90 | self.dataloader = DataLoader(self.img_dataset, **dataloader_kwargs) 91 | 92 | def __len__(self): 93 | return len(self.img_filepaths) 94 | 95 | @torch.no_grad() 96 | def process(self): 97 | n_embedded, n_skipped = 0, 0 98 | print(f"Embedding dataset of {len(self.img_filepaths)} images using {self.model_name}...") 99 | 100 | for batch_id, batch in enumerate(tqdm(self.dataloader)): 101 | # Batch now contains preprocessed crops directly from CustomImageDataset 102 | processed_crops, crop_names_batch, img_paths, img_feature_dict_batch = batch 103 | batch_size = processed_crops.shape[0] 104 | base_img_paths = [os.path.splitext(img_path)[0] for img_path in img_paths] 105 | feature_save_paths = [base_img_path + ".pt" for base_img_path in base_img_paths] 106 | # Adjust crop_names_batch structure if needed based on how CustomImageDataset returns it 107 | # Assuming it's now a list of lists [batch_size] x [n_crops] 108 | # crop_names_batch needs careful handling if not returned per-item by dataloader 109 | 110 | # Collapse batch and crop dimensions for encoding: 111 | # Input shape: [batch_size, n_crops, C, H, W] 112 | # Desired shape for encoder: [batch_size * n_crops, C, H, W] 113 | num_crops = processed_crops.shape[1] 114 | crops_stacked = processed_crops.view(-1, *processed_crops.shape[2:]) # [B*N_crops, C, H, W] 115 | crops_stacked = crops_stacked.to(self.device) # Ensure tensor is on the correct device 116 | 117 | # Check existence based on the specific model name 118 | existing_feature_save_paths = [p for p in feature_save_paths if os.path.exists(p)] 119 | already_encoded = 0 120 | for feature_save_path in existing_feature_save_paths: 121 | try: 122 | feature_dict = torch.load(feature_save_path, map_location='cpu') # Load to CPU first 123 | if self.model_name in feature_dict.keys(): # Check for the specific model key 124 | already_encoded += 1 125 | except Exception as e: 126 | print(f"Warning: Could not load existing feature file {feature_save_path}: {e}") 127 | 128 | if self.force_reencode or not already_encoded == batch_size: 129 | # Use the encoder instance to 
embed the stacked, preprocessed crops 130 | features = self.encoder.encode_image(crops_stacked) 131 | # Reshape features back to [batch_size, n_crops, dim] 132 | features = features.view(batch_size, num_crops, features.shape[-1]) 133 | 134 | # Save features (logic remains similar, but uses self.model_name as key) 135 | # Iterate through batch items, including the corresponding crop names list for each image 136 | for i, (feature_per_image, feature_save_path, img_path, current_crop_names) in enumerate(zip(features, feature_save_paths, img_paths, crop_names_batch)): 137 | # Load existing feature dict if it exists, otherwise create new 138 | final_feature_dict = {} 139 | if os.path.exists(feature_save_path) and not self.force_reencode: 140 | try: 141 | final_feature_dict = torch.load(feature_save_path, map_location='cpu') 142 | except Exception as e: 143 | print(f"Warning: Failed to load existing {feature_save_path} for update: {e}") 144 | 145 | # Extract per-crop features and store them 146 | feature_dict_for_model = {} 147 | 148 | # Load base image features (like HoG, FFT) - needs careful indexing 149 | # Assuming img_feature_dict_batch is structured correctly for the batch 150 | for img_feature_name in img_feature_dict_batch.keys(): 151 | # Ensure we are getting the correct item for the current image in the batch (index i) 152 | feature_dict_for_model[img_feature_name] = img_feature_dict_batch[img_feature_name][i] 153 | 154 | # Store features for each crop under its name 155 | # current_crop_names is now directly available from the loop 156 | for feature_crop, crop_name in zip(feature_per_image, current_crop_names): 157 | feature_dict_for_model[crop_name] = feature_crop.unsqueeze(0).cpu() # Store on CPU 158 | 159 | # Convert all tensors in the dict to float32 for consistency 160 | feature_dict_for_model = {k: v.float() if isinstance(v, torch.Tensor) else v 161 | for k, v in feature_dict_for_model.items()} 162 | 163 | # Add/update the features for the current model in the main dictionary 164 | final_feature_dict[self.model_name] = feature_dict_for_model 165 | 166 | # Save the updated dictionary 167 | try: 168 | torch.save(final_feature_dict, feature_save_path) 169 | except Exception as e: 170 | print(f"Error saving features to {feature_save_path}: {e}") 171 | 172 | n_embedded += batch_size 173 | else: 174 | n_skipped += batch_size 175 | 176 | if (n_embedded + n_skipped) > 0 and (n_embedded + n_skipped) % 1000 == 0: 177 | print(f"Processed {n_embedded + n_skipped} images. Skipped: {n_skipped}, Embedded: {n_embedded}") 178 | 179 | 180 | print("\n--- Feature encoding done! ---\n") 181 | print(f"Embedded {n_embedded} images ({n_skipped} images were already embedded). 
Features saved with model key '{self.model_name}'.") 182 | print(f"Feature vector dicts were saved alongside original images in {self.root_dir}") 183 | print(f"Crop names that were processed: {self.crop_names}") 184 | print("-----------------------------------------------\n\n") 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('--root_dir', type=str, required=True, help='Root directory of the dataset (can contain subdirectories)') 189 | # Rename argument from clip_models_to_use to models_to_use 190 | parser.add_argument('--models_to_use', type=str, nargs='+', default=['ViT-L-14-336/openai'], 191 | help='Which CLIP (e.g., ViT-L-14-336/openai) or PE (e.g., PE-Core-B16-224) models to use') 192 | parser.add_argument('--batch_size', type=int, default=8, help='Number of images to encode at once') 193 | parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for the dataloader') 194 | parser.add_argument('--force_reencode', action='store_true', help='Force re-encoding of all images for the specified models (default: False)') 195 | # Add model_path argument if needed for local CLIP models 196 | parser.add_argument('--model_path', type=str, default=None, help='Path to local directory for downloading/loading models (optional)') 197 | args = parser.parse_args() 198 | 199 | # Crop names remain the same 200 | crop_names = ['centre_crop', 'square_padded_crop', 'subcrop1', 'subcrop2'] 201 | 202 | mp.set_start_method('spawn') 203 | 204 | print(f"Embedding all imgs with {len(args.models_to_use)} models: \n--> {args.models_to_use}") 205 | 206 | # Loop through the specified models 207 | for model_name in args.models_to_use: 208 | print(f"\n--- Processing model: {model_name} ---") 209 | # Instantiate the renamed Feature_Dataset class 210 | dataset = Feature_Dataset(args.root_dir, model_name, args.batch_size, 211 | model_path = args.model_path, # Pass model_path 212 | force_reencode = args.force_reencode, 213 | num_workers = args.num_workers, 214 | crop_names = crop_names) 215 | dataset.process() -------------------------------------------------------------------------------- /_2_remove_duplicates.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import argparse 5 | import random 6 | from tqdm import tqdm 7 | 8 | def get_paths_and_embeddings(args, crop_to_use, shuffle = False): 9 | 10 | for subdir, dirs, files in os.walk(args.root_dir): 11 | print(f"\nParsing {subdir}, subdirs: {dirs}, n_files: {len(files)}..") 12 | paths, embeddings = [], [] 13 | if shuffle: 14 | random.shuffle(files) # shuffle the list of files 15 | 16 | # Get all the unique filenames (without the extension) and store a list of present extensions for each one: 17 | unique_filenames = {} 18 | for file in files: 19 | filename, file_extension = os.path.splitext(file) 20 | if filename not in unique_filenames: 21 | unique_filenames[filename] = [] 22 | unique_filenames[filename].append(file_extension) 23 | 24 | print(f"Loading embeddings for {len(unique_filenames)} unique filenames..") 25 | for filename in tqdm(unique_filenames.keys()): 26 | extension_list = unique_filenames[filename] 27 | if '.jpg' in extension_list and '.pt' in extension_list: 28 | try: 29 | path = os.path.join(subdir, filename + '.jpg') 30 | embedding_dict = torch.load(os.path.join(subdir, filename + '.pt')) 31 | 32 | if args.clip_model_to_use is None: # use the first key in the embedding_dict: 33 | 
args.clip_model_to_use = list(embedding_dict.keys())[0] 34 | print(f"\n ----> args.clip_model_to_use was not specified, defaulting to first found one: {args.clip_model_to_use} \n") 35 | 36 | embedding_dict = embedding_dict[args.clip_model_to_use] 37 | 38 | embedding = embedding_dict[crop_to_use].squeeze().to(torch.float16) 39 | paths.append(path) 40 | embeddings.append(embedding) 41 | 42 | if len(paths) == args.chunk_size: 43 | yield paths, embeddings 44 | paths, embeddings = [], [] 45 | except: 46 | continue 47 | 48 | if len(paths) > 0: 49 | yield paths, embeddings 50 | 51 | 52 | def find_near_duplicates(args, 53 | sim_type = 'cosine', # ['cosine', 'euclidean'] 54 | crop_to_use = 'square_padded_crop', # which crop CLIP embedding do we compute the similarity on? 55 | ): 56 | 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | 59 | for paths, embeddings in get_paths_and_embeddings(args, crop_to_use): 60 | if len(paths) == 0 or len(embeddings) == 0: 61 | continue 62 | 63 | embeddings = torch.stack(embeddings).to(device) 64 | 65 | # Compute the similarity matrix: 66 | print(f"Got first batch of embeddings of shape: {embeddings.shape}, computing similarity matrix..") 67 | normalized_embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True) 68 | if sim_type == 'cosine': 69 | similarity_matrix = torch.matmul(normalized_embeddings, normalized_embeddings.T) 70 | elif sim_type == 'euclidean': 71 | similarity_matrix = torch.cdist(normalized_embeddings, normalized_embeddings) 72 | 73 | # Find the near duplicates using torch.where: 74 | near_duplicate_indices = torch.where(torch.triu(similarity_matrix, diagonal=1) > args.threshold) 75 | # convert indices to lists of integers: 76 | near_duplicate_indices = list(zip(near_duplicate_indices[0].tolist(), near_duplicate_indices[1].tolist())) 77 | near_duplicates = [(paths[i], paths[j]) for i, j in near_duplicate_indices] 78 | 79 | # Get the actual similarity_value for each duplicate pair: 80 | near_duplicate_values = [similarity_matrix[i, j].item() for i, j in near_duplicate_indices] 81 | 82 | # Create a folder for the near duplicates next to the root_dir: 83 | output_dir = os.path.join(os.path.dirname(args.root_dir), f"near_duplicates_{sim_type}_{args.threshold}") 84 | os.makedirs(output_dir, exist_ok=True) 85 | 86 | i = 0 87 | print(f"Found {len(near_duplicates)} duplicates!") 88 | 89 | if len(near_duplicates) > 0 and not args.test: 90 | verb = "copying" if args.mode == 'copy' else "moving" 91 | print(f"{verb} {len(near_duplicates)} near duplicates to {output_dir}...") 92 | 93 | for i, (img_paths, sim_value) in enumerate(zip(near_duplicates, near_duplicate_values)): 94 | fix_duplicate(i, img_paths, output_dir, sim_value, args.mode) 95 | 96 | if args.mode == 'move': 97 | print(f"Moved {i} duplicates to {output_dir}") 98 | elif args.mode == 'copy': 99 | print(f"Copied {i} duplicates (not removed from data yet!) 
to {output_dir}") 100 | 101 | 102 | def fix_duplicate(duplicate_index, img_paths, outdir, sim_value, mode): 103 | # TODO: Remove the duplicate with the lowest predicted aesthetic score 104 | 105 | dirname = os.path.dirname(img_paths[0]) 106 | # get the two basenames without extensions: 107 | basename1 = os.path.splitext(os.path.basename(img_paths[0]))[0] 108 | basename2 = os.path.splitext(os.path.basename(img_paths[1]))[0] 109 | 110 | # find all files with this same basename: 111 | files1 = [os.path.join(dirname, f) for f in os.listdir(os.path.dirname(img_paths[0])) if basename1 in f] 112 | files2 = [os.path.join(dirname, f) for f in os.listdir(os.path.dirname(img_paths[1])) if basename2 in f] 113 | 114 | # copy all files to the output directory: 115 | for f in files1: 116 | if mode == 'copy': 117 | shutil.copy(f, os.path.join(outdir, f"{sim_value:.3f}_{duplicate_index:08d}_source_{os.path.basename(f)}")) 118 | 119 | for f in files2: 120 | if mode == 'copy': 121 | shutil.copy(f, os.path.join(outdir, f"{sim_value:.3f}_{duplicate_index:08d}_target_{os.path.basename(f)}")) 122 | if mode == 'move': 123 | os.rename(f, os.path.join(outdir, f"{sim_value:.3f}_{duplicate_index:08d}_target_{os.path.basename(f)}")) 124 | 125 | return 126 | 127 | 128 | if __name__ == '__main__': 129 | 130 | """ 131 | Scan for near duplicate imgs using CILP embeddings 132 | and copy / move those to a new folder 133 | """ 134 | 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('--root_dir', type=str, help='Root directory of the dataset') 137 | parser.add_argument('--threshold', type=float, default=0.96, help='Cosine-similarity threshold for near-duplicate detection') 138 | parser.add_argument('--mode', type=str, default='copy', help='copy / move, Use copy to test the script, move after') 139 | parser.add_argument('--clip_model_to_use', type=str, default=None, help='Which CLIP model to use, if None, use the first one found') 140 | parser.add_argument('--chunk_size', type=int, default=10000, help='Chunk the duplicate detection into batches of this size to avoid OOM') 141 | parser.add_argument('--test', action='store_true', help='Test the script without doing anything') 142 | args = parser.parse_args() 143 | 144 | find_near_duplicates(args) 145 | -------------------------------------------------------------------------------- /_3_label_images.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | from tqdm import tqdm 4 | import os 5 | import glob 6 | import time 7 | import random 8 | import argparse 9 | import pandas as pd 10 | from PIL import Image 11 | import numpy as np 12 | import tkinter as tk 13 | from tkinter import ttk 14 | import shutil 15 | 16 | from natsort import natsorted, ns 17 | def nautilus_sort(filenames): 18 | # Sort filenames naturally and case-insensitively 19 | return natsorted(filenames, alg=ns.IGNORECASE) 20 | 21 | def create_backup(database_path): 22 | folder = os.path.dirname(database_path) 23 | files = glob.glob(folder + "/*") 24 | 25 | # If a backup already exists, delete it: 26 | for file in files: 27 | if "_db_backup_" in file: 28 | os.remove(file) 29 | 30 | # Create a backup of the database: 31 | timestamp = time.strftime("%Y%m%d-%H%M%S") 32 | new_backup_path = database_path.replace(".csv", f"_db_backup_{timestamp}.csv") 33 | shutil.copy(database_path, new_backup_path) 34 | print("Created a backup of the database at ", new_backup_path) 35 | 36 | 37 | 38 | def create_sorting_window(): 39 | def on_closing(): 40 | 
sorting_window.quit() 41 | 42 | def on_sort_button_click(): 43 | global selected_option 44 | selected_option = sorting_var.get() 45 | on_closing() 46 | 47 | sorting_window = tk.Tk() 48 | sorting_window.protocol("WM_DELETE_WINDOW", on_closing) 49 | sorting_window.title("Sort Options") 50 | 51 | sorting_var = tk.StringVar() 52 | sorting_var.set("uuid") 53 | 54 | radio1 = ttk.Radiobutton( 55 | sorting_window, text="UUID", variable=sorting_var, value="uuid" 56 | ) 57 | radio2 = ttk.Radiobutton( 58 | sorting_window, 59 | text="Predicted bad first", 60 | variable=sorting_var, 61 | value="Predicted bad first", 62 | ) 63 | radio3 = ttk.Radiobutton( 64 | sorting_window, 65 | text="Predicted good first", 66 | variable=sorting_var, 67 | value="Predicted good first", 68 | ) 69 | radio4 = ttk.Radiobutton( 70 | sorting_window, 71 | text="middle first", 72 | variable=sorting_var, 73 | value="middle", 74 | ) 75 | radio5 = ttk.Radiobutton( 76 | sorting_window, 77 | text="diversity sorted", 78 | variable=sorting_var, 79 | value="diversity", 80 | ) 81 | 82 | sort_button = ttk.Button(sorting_window, text="Sort", command=on_sort_button_click) 83 | 84 | radio1.grid(row=0, column=0, padx=10, pady=10) 85 | radio2.grid(row=1, column=0, padx=10, pady=10) 86 | radio3.grid(row=2, column=0, padx=10, pady=10) 87 | radio4.grid(row=3, column=0, padx=10, pady=10) 88 | radio5.grid(row=4, column=0, padx=10, pady=10) 89 | sort_button.grid(row=5, column=0, padx=10, pady=10) 90 | 91 | sorting_window.mainloop() 92 | return selected_option 93 | 94 | 95 | 96 | def resize(cv_img, size = (1706, 960)): 97 | canvas = Image.new('RGB', size, (0, 0, 0)) 98 | 99 | # Resize the image so it fits on the canvas: 100 | height, width, _ = cv_img.shape 101 | ratio = min(size[0] / width, size[1] / height) 102 | 103 | cv_img = cv2.resize(cv_img, (int(width * ratio), int(height * ratio))) 104 | 105 | # paste the image onto the canvas: 106 | height, width, _ = cv_img.shape 107 | canvas.paste(Image.fromarray(cv_img), (int((size[0] - width) / 2), int((size[1] - height) / 2))) 108 | 109 | return np.array(canvas) 110 | 111 | 112 | def relabel_image(uuid, label, database): 113 | current_timestamp = int(time.time()) 114 | row = database.loc[database["uuid"] == uuid] 115 | if row is None or len(row) == 0: 116 | # Create a new entry in the database: 117 | new_row = {"uuid": uuid, "label": label, "timestamp": current_timestamp} 118 | database = pd.concat([database, pd.DataFrame([new_row])], ignore_index=True) 119 | else: 120 | # Update the existing entry: 121 | index_to_update = database.loc[database['uuid'] == uuid].index[0] 122 | # Update the values in the row 123 | database.loc[index_to_update, 'label'] = label 124 | database.loc[index_to_update, 'timestamp'] = current_timestamp 125 | 126 | return database 127 | 128 | @torch.no_grad() 129 | def cosine_similarity_matrix(a, b): 130 | a_norm = a / a.norm(dim=1, keepdim=True) 131 | b_norm = b / b.norm(dim=1, keepdim=True) 132 | return torch.matmul(a_norm, b_norm.t()) 133 | 134 | @torch.no_grad() 135 | def diversity_ordered_image_files(image_files, root_directory, total_n_ordered_imgs = 500, sample_size = 100): 136 | """ 137 | Tries to order the first total_n_ordered_imgs in a way that maximizes the diversity of that set in CLIP space. 138 | This is idea for starting a fresh labeling session, where you want to label the most diverse images first. 
139 | 140 | """ 141 | img_files = [image_files[0]] 142 | img_embedding = torch.load(os.path.join(root_directory, os.path.basename(img_files[0]).replace(".jpg", ".pt")))['square_padded_crop'] 143 | img_embedding = img_embedding.squeeze().unsqueeze(0) 144 | 145 | print("Creating the most CLIP-diverse ordering of the first ", total_n_ordered_imgs, " images...") 146 | 147 | for i in tqdm(range(min(total_n_ordered_imgs, len(image_files)-1))): 148 | # sample a random subset of the image files: 149 | sample = random.sample(image_files, sample_size) 150 | 151 | # get the corresponding .pt file for each: 152 | sample_pt_files = [os.path.join(root_directory, os.path.basename(f).replace(".jpg", ".pt")) for f in sample] 153 | 154 | # load the "square_padded_crop" CLIP embedding for each: 155 | sample_embeddings = [torch.load(f)['square_padded_crop'] for f in sample_pt_files] 156 | sample_embeddings = torch.stack(sample_embeddings, dim=0).squeeze() 157 | 158 | # compute the similarities between all current image embeddings and the embeddings of the sample: 159 | similarities = cosine_similarity_matrix(img_embedding, sample_embeddings) 160 | 161 | # Find the maximum similarity value for each sample (the current embedding it is closest to) 162 | max_val, _ = torch.max(similarities, dim=0) 163 | 164 | # Find the index of the sample with the smallest maximum similarity 165 | index_of_min = torch.argmin(max_val).item() 166 | 167 | # add the image with the lowest similarity to the ordered list: 168 | img_files.append(sample[index_of_min]) 169 | embedding_to_add = sample_embeddings[index_of_min].unsqueeze(0) 170 | 171 | # aappend the embedding of the image with the lowest similarity to the current embedding: 172 | img_embedding = torch.cat((img_embedding, embedding_to_add), dim=0) 173 | 174 | # add the remaining images to the ordered list: 175 | img_files = img_files + [f for f in image_files if f not in img_files] 176 | 177 | return img_files 178 | 179 | 180 | def re_order_images(image_files, database, root_directory): 181 | ''' 182 | Takes the pandas dataframe database and sorts the image files according to the "predicted_label" column. 
183 | ''' 184 | sorting_option = create_sorting_window() 185 | 186 | if sorting_option == "uuid": 187 | return image_files 188 | 189 | elif sorting_option == "diversity": 190 | return diversity_ordered_image_files(image_files, root_directory) 191 | 192 | else: 193 | # Modify the image_files sorting according to the selected option 194 | if sorting_option == "Predicted bad first": 195 | sorted_indices = database['predicted_label'].argsort().values 196 | 197 | elif sorting_option == "Predicted good first": 198 | sorted_indices = database['predicted_label'].argsort().values[::-1] 199 | 200 | elif sorting_option == "middle": 201 | # Get the median value of the predicted labels: 202 | median = database['predicted_label'].median() 203 | # Get the distance of each predicted label from the median: 204 | database['distance_from_median'] = abs(database['predicted_label'] - median) 205 | # Sort the database by the distance from the median: 206 | sorted_indices = database['distance_from_median'].argsort().values 207 | 208 | # get the uuids of those rows in the database: 209 | uuids = database['uuid'].values[sorted_indices] 210 | # get the image files that correspond to those uuids: 211 | possible_image_files = [os.path.join(root_directory, uuid + ".jpg") for uuid in uuids] 212 | 213 | return [f for f in possible_image_files if f in image_files] 214 | 215 | def is_already_labeled(label): 216 | return (label != "") and (label is not None) and (not np.isnan(label)) 217 | 218 | def print_label_info(database, columns = ["uuid", "label", "predicted_label"]): 219 | n_labeled = sum(map(is_already_labeled, database['label'])) 220 | print(f"{n_labeled} of {len(database)} images in the database labeled") 221 | 222 | def draw_progress_bar(image, progress, total, height=10, color=(0, 255, 0), thickness=-1): 223 | rows, cols, _ = image.shape 224 | 225 | progress_bar_width = int(cols * 0.8) 226 | progress_bar_start_x = int(cols * 0.1) 227 | progress_bar_end_x = progress_bar_start_x + progress_bar_width 228 | progress_bar_y = rows - height 229 | 230 | cv2.rectangle(image, (progress_bar_start_x, progress_bar_y), (progress_bar_end_x, rows), (255, 255, 255), thickness) 231 | 232 | progress_width = int((progress / total) * progress_bar_width) 233 | cv2.rectangle(image, (progress_bar_start_x, progress_bar_y), (progress_bar_start_x + progress_width, rows), color, thickness) 234 | 235 | 236 | def fix_database(database): 237 | # Loop over all rows of the dataframe 238 | # When a row has the "label" column filled in, copy that value to the predicted_label column: 239 | for index, row in database.iterrows(): 240 | if is_already_labeled(row['label']): 241 | database.loc[index, 'predicted_label'] = row['label'] 242 | 243 | return database 244 | 245 | import json 246 | def load_image_and_prompt(uuid, root_directory): 247 | image_filepath = os.path.join(root_directory, uuid + ".jpg") 248 | image = cv2.imread(image_filepath) 249 | prompt = '' 250 | 251 | 252 | txt_filepath = os.path.join(root_directory, uuid + ".txt") 253 | if os.path.exists(txt_filepath): 254 | for line in open(txt_filepath, "r"): 255 | prompt = line 256 | 257 | json_filepath = os.path.join(root_directory, uuid + ".json") 258 | if os.path.exists(json_filepath): 259 | with open(json_filepath, "r") as f: 260 | json_data = json.load(f) 261 | try: 262 | prompt = json_data['text_input'] 263 | except: 264 | prompt = "" 265 | 266 | return image, prompt 267 | 268 | def load(uuid, database): 269 | # Check if this uuid is already in the database: 270 | row = 
database.loc[database["uuid"] == uuid] 271 | 272 | if row is None or len(row) == 0: 273 | return None 274 | else: 275 | return row["label"].values[0] 276 | 277 | def label_dataset(root_directory, skip_labeled_files = True): 278 | label_file = os.path.join(os.path.dirname(root_directory), os.path.basename(root_directory) + ".csv") 279 | image_files = nautilus_sort(glob.glob(os.path.join(root_directory, "**/*.jpg"), recursive=True)) 280 | 281 | if os.path.exists(label_file): 282 | database = pd.read_csv(label_file) 283 | create_backup(label_file) 284 | else: 285 | database = pd.DataFrame(columns=["uuid", "label", "timestamp", "predicted_label"]) 286 | 287 | # count how many rows have the label column filled in: 288 | labeled_count = len(database.loc[database["label"].notnull()]) 289 | print(f"Found {labeled_count} labeled images ({len(image_files)} total) in {label_file}") 290 | 291 | database = fix_database(database) 292 | image_files = re_order_images(image_files, database, root_directory) 293 | current_index = 0 294 | extra_labels = 0 295 | 296 | while True: 297 | image_file = image_files[current_index] 298 | uuid = os.path.splitext(os.path.basename(image_file))[0] 299 | label = load(uuid, database) 300 | if (label is not None) and (not np.isnan(label)) and skip_labeled_files: 301 | current_index += 1 302 | continue 303 | 304 | skip_labeled_files = False 305 | image, prompt = load_image_and_prompt(uuid, root_directory) 306 | image = resize(image) 307 | 308 | if label is not None and not np.isnan(label): 309 | cv2.putText(image, f"{label:.2f} || {prompt}", (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (200, 100, 25), 2) 310 | else: 311 | try: 312 | # Get the predicted label from the database: 313 | predicted_label = database.loc[database["uuid"] == uuid, "predicted_label"].values[0] 314 | cv2.putText(image, f"predicted: {predicted_label:.3f} || {prompt}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (200, 100, 25), 2) 315 | except: 316 | cv2.putText(image, f"{prompt}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (200, 100, 25), 2) 317 | 318 | draw_progress_bar(image, current_index, len(image_files)) 319 | cv2.namedWindow("image", cv2.WINDOW_AUTOSIZE) # Set the window property to autosize 320 | cv2.imshow("image", image) # Display the image in the "image" window 321 | key = cv2.waitKey(0) 322 | 323 | if ord('0') <= key <= ord('9'): 324 | label = (key - ord('0')) / 10.0 325 | database = relabel_image(uuid, label, database) 326 | current_index += 1 327 | extra_labels += 1 328 | 329 | if extra_labels % 5 == 0: 330 | database.to_csv(label_file, index=False) 331 | print_label_info(database) 332 | 333 | elif key == ord('q') or key == 27: # 'q' or 'esc' key 334 | break 335 | elif key == 81: # left arrow key 336 | current_index -= 1 337 | elif key == 83: # right arrow key 338 | current_index += 1 339 | 340 | current_index = current_index % len(image_files) 341 | 342 | cv2.destroyAllWindows() 343 | database.to_csv(label_file, index=False) 344 | print_label_info(database) 345 | 346 | 347 | if __name__ == "__main__": 348 | 349 | """ 350 | Fire up a very basic opencv labeling interface 351 | """ 352 | 353 | parser = argparse.ArgumentParser() 354 | parser.add_argument('--root_dir', type=str, help='Root directory of the dataset') 355 | parser.add_argument('--skip_labeled_files', action='store_true', help='Skip files that are already labeled') 356 | args = parser.parse_args() 357 | 358 | label_dataset(args.root_dir, args.skip_labeled_files) 
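The key codes 81/83 used above are what cv2.waitKey() returns for the arrow keys on Ubuntu; as noted in the README TODO, a small keyboard-mapping class could make _3_label_images.py portable. A rough sketch follows (the non-Linux codes are commonly reported values, not verified here):

```
import sys

class KeyMap:
    """Arrow-key codes returned by cv2.waitKey(); verify on each target system."""
    def __init__(self):
        if sys.platform.startswith("linux"):
            self.left, self.right = 81, 83
        elif sys.platform == "darwin":          # macOS
            self.left, self.right = 2, 3
        else:                                   # Windows, when using cv2.waitKeyEx()
            self.left, self.right = 2424832, 2555904

# inside label_dataset() the hardcoded comparisons would then become:
#   keymap = KeyMap()
#   elif key == keymap.left:  current_index -= 1
#   elif key == keymap.right: current_index += 1
```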
-------------------------------------------------------------------------------- /_4_train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import argparse 5 | import pickle 6 | from tqdm import tqdm 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader, Dataset, random_split 10 | from torch.optim import Adam 11 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 12 | import matplotlib.pyplot as plt 13 | from utils.nn_model import device, SimpleFC 14 | from sklearn.metrics import r2_score 15 | 16 | def train(args, crop_names, use_img_stat_features): 17 | 18 | torch.manual_seed(args.random_seed) 19 | np.random.seed(args.random_seed) 20 | 21 | features = [] 22 | labels = [] 23 | 24 | if args.clip_models_to_use[0] != "all": 25 | print(f"\n----> Using clip models: {args.clip_models_to_use}") 26 | 27 | # Load all the labeled training data from disk: 28 | for train_data_name in args.train_data_names: 29 | n_samples = 0 30 | skips = 0 31 | 32 | # Load the labels and uuid's from labels.csv 33 | data = pd.read_csv(os.path.join(args.train_data_dir, train_data_name + '.csv')) 34 | # Drop all the rows where "label" is NaN: 35 | data = data.dropna(subset=["label"]) 36 | # randomly shuffle the data: 37 | data = data.sample(frac=1).reset_index(drop=True) 38 | 39 | # Load the feature vectors from disk (uuid.pt) 40 | print(f"\nLoading {train_data_name} features from disk...") 41 | 42 | for index, row in tqdm(data.iterrows()): 43 | try: 44 | uuid = row["uuid"] 45 | label = row["label"] 46 | full_feature_dict = torch.load(f"{args.train_data_dir}/{train_data_name}/{uuid}.pt") 47 | 48 | if args.clip_models_to_use[0] == "all": 49 | args.clip_models_to_use = list(full_feature_dict.keys()) 50 | print(f"\n----> Using all found clip models: {args.clip_models_to_use}") 51 | 52 | sample_features = [] 53 | 54 | for clip_model_name in args.clip_models_to_use: 55 | feature_dict = full_feature_dict[clip_model_name] 56 | clip_features = torch.cat([feature_dict[crop_name] for crop_name in crop_names if crop_name in feature_dict], dim=0).flatten() 57 | missing_crops = set(crop_names) - set(feature_dict.keys()) 58 | if missing_crops: 59 | raise Exception(f"Missing crops {missing_crops} for {uuid}, either re-embed the image, or adjust the crop_names variable for training!") 60 | 61 | if use_img_stat_features: 62 | img_stat_feature_names = [key for key in feature_dict.keys() if key.startswith("img_stat_")] 63 | img_stat_features = torch.stack([feature_dict[img_stat_feature_name] for img_stat_feature_name in img_stat_feature_names], dim=0).to(device) 64 | all_features = torch.cat([clip_features, img_stat_features], dim=0) 65 | else: 66 | all_features = clip_features 67 | 68 | sample_features.append(all_features) 69 | 70 | features.append(torch.cat(sample_features, dim=0)) 71 | labels.append(label) 72 | n_samples += 1 73 | except Exception as e: # simply skip the sample if something goes wrong 74 | skips += 1 75 | continue 76 | 77 | print(f"Loaded {n_samples} samples from {train_data_name}!") 78 | if skips > 0: 79 | print(f"(skipped {skips} samples due to loading errors)..") 80 | 81 | features = torch.stack(features, dim=0).to(device).float() 82 | labels = torch.tensor(labels).to(device).float() 83 | 84 | # Map the labels to 0-1: 85 | print("Normalizing labels to [0,1]...") 86 | print(f"min: {labels.min()}, max: {labels.max()}") 87 | labels_min, labels_max = labels.min(), 
labels.max() 88 | labels = (labels - labels_min) / (labels_max - labels_min) 89 | 90 | print("\n--- All data loaded ---") 91 | print("Features shape:", features.shape) 92 | print("Labels shape:", labels.shape) 93 | 94 | # 2. Create train and test dataloaders 95 | class RegressionDataset(Dataset): 96 | def __init__(self, features, labels): 97 | self.features = features 98 | self.labels = labels 99 | 100 | def __len__(self): 101 | return len(self.features) 102 | 103 | def __getitem__(self, idx): 104 | return self.features[idx], self.labels[idx] 105 | 106 | dataset = RegressionDataset(features, labels) 107 | train_size = int((1-args.test_fraction) * len(dataset)) 108 | test_size = len(dataset) - train_size 109 | 110 | print(f"Training on {train_size} samples, testing on {test_size} samples.") 111 | 112 | train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) 113 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 114 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 115 | 116 | # 3. Create the regression network: 117 | model = SimpleFC(features.shape[1], args.hidden_sizes, 1, args.clip_models_to_use, 118 | crop_names = crop_names, 119 | dropout_prob = args.dropout_prob, 120 | verbose = args.print_network_layout) 121 | model.train() 122 | model.to(device) 123 | 124 | # 4. Train the network using Adam optimizer with cosine learning rate scheduler 125 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 126 | scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=args.restart_epochs, T_mult=1, eta_min=args.min_lr) 127 | criterion = nn.MSELoss() 128 | losses = [[], []] # train, test losses 129 | lrs = [] # learning rates 130 | 131 | def get_test_loss(model, test_loader, epoch, plot_correlation=1): 132 | if len(test_loader) == 0: 133 | return -1.0, -1.0 134 | model.eval() 135 | test_loss, dummy_test_loss = 0.0, 0.0 136 | test_preds, test_labels = [], [] 137 | with torch.no_grad(): 138 | for features, labels in test_loader: 139 | outputs = model(features) 140 | loss = criterion(outputs.squeeze(), labels) 141 | test_loss += loss.item() 142 | 143 | dummy_outputs = torch.ones_like(outputs) * labels.mean() 144 | dummy_loss = criterion(dummy_outputs.squeeze(), labels) 145 | dummy_test_loss += dummy_loss.item() 146 | 147 | if plot_correlation: 148 | test_preds.append(outputs.cpu().numpy()) 149 | test_labels.append(labels.cpu().numpy()) 150 | 151 | if plot_correlation and epoch % 5 == 0: 152 | test_preds = np.concatenate(test_preds, axis=0) 153 | test_labels = np.concatenate(test_labels, axis=0) 154 | plt.figure(figsize=(8, 8)) 155 | plt.scatter(test_labels, test_preds, alpha=0.1) 156 | plt.xlabel("True labels") 157 | plt.ylabel("Predicted labels") 158 | plt.plot([0, 1], [0, 1], color='r', linestyle='--') 159 | plt.title(f"Epoch {epoch}, r² = {r2_score(test_labels, test_preds):.3f}") 160 | plt.xlim(0, 1) 161 | plt.ylim(0, 1) 162 | plt.savefig("test_set_predictions.png") 163 | plt.close() 164 | 165 | test_loss /= len(test_loader) 166 | dummy_test_loss /= len(test_loader) 167 | model.train() 168 | return test_loss, dummy_test_loss 169 | 170 | def plot_losses(losses, lrs, y_axis_percentile_cutoff=99.75, include_y_zero=1): 171 | # Plot losses 172 | plt.figure(figsize=(16, 8)) 173 | plt.subplot(1, 2, 1) 174 | plt.plot(losses[0], label="Train") 175 | plt.plot(losses[1], label="Test") 176 | plt.axhline(y=min(losses[1]), color='r', linestyle='--', label="Best test loss") 177 | all_losses = 
losses[0] + losses[1] 178 | if include_y_zero: 179 | plt.ylim(0, np.percentile(all_losses, y_axis_percentile_cutoff)) 180 | else: 181 | plt.ylim(np.min(all_losses), np.percentile(all_losses, y_axis_percentile_cutoff)) 182 | plt.xlabel("Epoch") 183 | plt.ylabel("MSE loss") 184 | plt.legend() 185 | 186 | # Plot learning rate 187 | plt.subplot(1, 2, 2) 188 | plt.plot(lrs, label="Learning Rate") 189 | plt.xlabel("Epoch") 190 | plt.ylabel("Learning Rate") 191 | plt.legend() 192 | 193 | plt.tight_layout() 194 | plt.savefig("training_progress.png") 195 | plt.close() 196 | 197 | test_loss, dummy_test_loss = get_test_loss(model, test_loader, -1) 198 | print(f"\nBefore training, test mse-loss: {test_loss:.4f} (dummy: {dummy_test_loss:.4f})") 199 | 200 | for epoch in range(args.n_epochs): 201 | model.train() 202 | train_loss = 0.0 203 | for features, labels in train_loader: 204 | optimizer.zero_grad() 205 | outputs = model(features) 206 | loss = criterion(outputs.squeeze(), labels) 207 | loss.backward() 208 | optimizer.step() 209 | train_loss += loss.item() 210 | 211 | # Step the scheduler 212 | scheduler.step() 213 | current_lr = scheduler.get_last_lr()[0] 214 | lrs.append(current_lr) 215 | 216 | train_loss = train_loss / len(train_loader) 217 | test_loss, dummy_test_loss = get_test_loss(model, test_loader, epoch) 218 | losses[0].append(train_loss) 219 | losses[1].append(test_loss) 220 | if epoch % 2 == 0: 221 | test_str = f", test mse: {test_loss:.4f} (dummy: {dummy_test_loss:.4f})" if test_loss > 0 else "" 222 | print(f"Epoch {epoch+1}/{args.n_epochs}, train-mse: {train_loss:.4f}, lr: {current_lr:.6f}{test_str}") 223 | if epoch % (args.n_epochs // 10) == 0: 224 | plot_losses(losses, lrs) 225 | 226 | # Report: 227 | if test_loss > 0: 228 | print(f"---> Best test mse loss: {min(losses[1]):.4f} in epoch {np.argmin(losses[1])+1}") 229 | plot_losses(losses, lrs) 230 | 231 | if not args.dont_save: # Save the model 232 | model.eval() 233 | timestamp = pd.Timestamp.now().strftime("%Y-%m-%d_%H:%M:%S") 234 | model_save_name = f"{args.model_name}_{timestamp}_{(len(train_dataset) / 1000):.1f}k_imgs_{args.n_epochs}_epochs_{losses[1][-1]:.4f}_mse" 235 | os.makedirs("models", exist_ok=True) 236 | 237 | torch.save(model, f"models/{model_save_name}.pth") 238 | print("Final model saved to /model dir as:\n", f"{model_save_name}.pth") 239 | 240 | if __name__ == "__main__": 241 | parser = argparse.ArgumentParser() 242 | 243 | # IO args: 244 | parser.add_argument('--train_data_dir', type=str, help='Root directory of the (optionally multiple) datasets') 245 | parser.add_argument('--train_data_names', type=str, nargs='+', help='Names of the dataset files to train on (space separated)') 246 | parser.add_argument('--model_name', type=str, default='regressor', help='Name of the model when saved to disk') 247 | parser.add_argument('--dont_save', action='store_true', help='skip saving the model to disk') 248 | 249 | # Training args: 250 | parser.add_argument('--clip_models_to_use', metavar='S', type=str, nargs='+', default=['all'], help='Which CLIP model embeddings to use, default: use all found') 251 | parser.add_argument('--test_fraction', type=float, default=0.25, help='Fraction of the training data to use for testing') 252 | parser.add_argument('--n_epochs', type=int, default=60, help='Number of epochs to train for') 253 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training') 254 | parser.add_argument('--lr', type=float, default=0.0002, help='Initial learning rate') 255 | 
parser.add_argument('--min_lr', type=float, default=1e-6, help='Minimum learning rate for cosine scheduler') 256 | parser.add_argument('--restart_epochs', type=int, default=10, help='Number of epochs before learning rate restart') 257 | parser.add_argument('--weight_decay', type=float, default=0.0006, help='Weight decay for the Adam optimizer') 258 | parser.add_argument('--dropout_prob', type=float, default=0.5, help='Dropout probability') 259 | parser.add_argument('--hidden_sizes', type=int, nargs='+', default=[264,128,64], help='Hidden sizes of the FC neural network') 260 | 261 | parser.add_argument('--print_network_layout', action='store_true', help='Print the network layout') 262 | parser.add_argument('--random_seed', type=int, default=42, help='Random seed for reproducibility') 263 | args = parser.parse_args() 264 | 265 | # Custom switches to turn on/off certain features: 266 | crop_names = ['centre_crop', 'square_padded_crop', 'subcrop1_0.15', 'subcrop2_0.1'] # 0.265 267 | crop_names = ['centre_crop', 'subcrop2_0.1'] # 0.27 268 | #crop_names = ['square_padded_crop', 'subcrop2_0.1'] # 0.275 269 | #crop_names = ['centre_crop'] # 0.285 270 | #crop_names = ['square_padded_crop'] # 0.29 271 | #crop_names = ['subcrop1_0.15'] # 0.30 272 | #crop_names = ['subcrop2_0.1'] # 0.31 273 | 274 | use_img_stat_features = 0 275 | 276 | train(args, crop_names, use_img_stat_features) -------------------------------------------------------------------------------- /_5_predict_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch, shutil 5 | import pickle, time 6 | import random 7 | from tqdm import tqdm 8 | import argparse 9 | from matplotlib import pyplot as plt 10 | import json 11 | from utils.nn_model import device, SimpleFC 12 | import torch 13 | from torch.utils.data import Dataset, DataLoader 14 | import torch.multiprocessing as mp 15 | 16 | def plot_label_distribution(database, args, max_x = 1.0): 17 | # Save a plot of the label distribution 18 | fig, ax = plt.subplots(figsize=(10, 6)) 19 | 20 | # Create the histogram 21 | n, bins, patches = ax.hist(database['predicted_label'].values, bins=100, alpha=0.75, color='blue', edgecolor='black') 22 | 23 | # Customize the plot appearance 24 | ax.set_title(f'Label Distribution for {os.path.basename(args.root_dir)}', fontsize=18) 25 | ax.set_xlabel('Predicted Label', fontsize=14) 26 | ax.set_ylabel('Frequency', fontsize=14) 27 | ax.grid(axis='y', alpha=0.75, linestyle='--') 28 | 29 | # Add a text box with mean and standard deviation 30 | mu = np.mean(database['predicted_label'].values) 31 | sigma = np.std(database['predicted_label'].values) 32 | textstr = f'$\mu={mu:.2f}$\n$\sigma={sigma:.2f}$' 33 | props = dict(boxstyle='round', facecolor='white', alpha=0.8) 34 | ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=12, 35 | verticalalignment='top', bbox=props) 36 | 37 | # Set a custom y-axis tick format 38 | ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda x, loc: f"{int(x):,}")) 39 | 40 | # Set the x-axis range 41 | ax.set_xlim(left=0, right=max_x) 42 | 43 | # Save the plot in the parent dir of args.root_dir 44 | output_dir = os.path.dirname(args.root_dir) 45 | plt.savefig(os.path.join(output_dir, f"label_distribution_{os.path.basename(args.root_dir)}.png")) 46 | plt.close() 47 | 48 | def find_model(model_name, model_dir = "models"): 49 | if os.path.exists(model_name) and os.path.isfile(model_name) and 
model_name.endswith(".pkl"): 50 | return model_name 51 | model_files = os.listdir(model_dir) 52 | for model_file in model_files: 53 | if model_name in model_file: 54 | # return the absolute path to the model file: 55 | return os.path.join(model_dir, model_file) 56 | return None 57 | 58 | 59 | class CustomDataset(Dataset): 60 | def __init__(self, uuids, model, args): 61 | self.uuids = uuids 62 | self.model = model # prediction model 63 | self.args = args 64 | self.feature_shape = None 65 | 66 | def __len__(self): 67 | return len(self.uuids) 68 | 69 | def __getitem__(self, idx): 70 | uuid = self.uuids[idx] 71 | img_path = os.path.join(self.args.root_dir, uuid + '.jpg') 72 | feature_path = os.path.join(self.args.root_dir, uuid + '.pt') 73 | 74 | try: 75 | full_feature_dict = torch.load(feature_path) 76 | sample_features = [] 77 | for clip_model_name in self.args.clip_models: 78 | feature_dict = full_feature_dict[clip_model_name] 79 | clip_features = torch.cat([feature_dict[crop_name] for crop_name in self.model.crop_names if crop_name in feature_dict], dim=0).flatten() 80 | sample_features.append(clip_features) 81 | 82 | img_features = torch.cat(sample_features, dim=0).flatten() 83 | self.feature_shape = img_features.shape 84 | except Exception as e: 85 | print(f"WARNING: {str(e)} for {uuid}, skipping this sample..") 86 | return "", "", torch.zeros(self.feature_shape, device=device) 87 | 88 | return uuid, img_path, img_features 89 | 90 | @torch.no_grad() 91 | def predict_labels(args): 92 | 93 | model_file = find_model(args.model_file) 94 | if model_file is None: 95 | print(f"ERROR: could not find model file {model_file}!") 96 | exit() 97 | 98 | output_dir = args.root_dir + '_predicted_scores' 99 | os.makedirs(output_dir, exist_ok=True) 100 | 101 | 102 | print(model_file) 103 | if not os.path.exists(model_file): 104 | print(f"ERROR: model file {model_file} does not exist!") 105 | exit() 106 | 107 | model = torch.load(model_file) 108 | model.eval() 109 | args.clip_models = model.clip_models 110 | print("Loaded regression model trained on the following CLIP models:") 111 | print(args.clip_models) 112 | 113 | label_file = os.path.join(os.path.dirname(args.root_dir), os.path.basename(args.root_dir) + ".csv") 114 | if os.path.exists(label_file): 115 | database = pd.read_csv(label_file) 116 | print(f"Loaded existing database: {label_file}.\nDatabase contains {len(database)} entries") 117 | else: 118 | database = pd.DataFrame(columns=["uuid", "label", "timestamp", "predicted_label"]) 119 | print(f"Created new database file at {label_file}") 120 | 121 | # add new column 'predicted_label' to the database if not yet present: 122 | if 'predicted_label' not in database.columns: 123 | database['predicted_label'] = np.nan 124 | 125 | # Get all *.jpg files in the input_directory: 126 | img_files = [os.path.splitext(f)[0] for f in os.listdir(args.root_dir) if f.endswith('.jpg')] 127 | dataset = CustomDataset(img_files, model, args) 128 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, prefetch_factor=2) 129 | print(f"Predicting labels for {len(dataset)} images...") 130 | 131 | n_predictions = 0 132 | 133 | for uuids, img_paths, features in tqdm(dataloader): 134 | uuids = list(uuids) 135 | predicted_labels = model(features.to(device).float()).cpu().numpy().squeeze() 136 | 137 | # filter out samples that have an empty string as uuid (those files caused errors in the dataloader): 138 | remove_indices = [i for i, uuid in enumerate(uuids) if uuid == ""] 139 
| uuids = [uuid for i, uuid in enumerate(uuids) if i not in remove_indices] 140 | img_paths = [img_path for i, img_path in enumerate(img_paths) if i not in remove_indices] 141 | predicted_labels = np.delete(predicted_labels, remove_indices) 142 | 143 | # Create a DataFrame for the current batch 144 | current_timestamp = np.full_like(predicted_labels, int(time.time())) 145 | batch_data = pd.DataFrame({'uuid': uuids, 'predicted_label': predicted_labels, 'timestamp': current_timestamp}) 146 | 147 | # Merge the 'predicted_label' and 'timestamp' columns from batch_data into the database based on the 'uuid' column 148 | database = database.merge(batch_data[['uuid', 'predicted_label', 'timestamp']], on='uuid', how='outer', suffixes=('', '_new')) 149 | 150 | # Update the 'predicted_label' and 'timestamp' columns in the database with the new values 151 | database['predicted_label'] = database['predicted_label_new'].where(database['predicted_label_new'].notna(), database['predicted_label']) 152 | database['timestamp'] = database['timestamp_new'].where(database['timestamp_new'].notna(), database['timestamp']) 153 | 154 | # Drop the temporary columns created during the merge 155 | database.drop(columns=['predicted_label_new', 'timestamp_new'], inplace=True) 156 | n_predictions += len(uuids) 157 | 158 | # add the "predicted_label" to the original uuid.json file: 159 | for uuid, label in zip(uuids, predicted_labels): 160 | label = label.item() 161 | json_file = os.path.join(args.root_dir, uuid + '.json') 162 | if os.path.exists(json_file): 163 | with open(json_file, 'r') as f: 164 | data = json.load(f) 165 | data['predicted_label'] = label 166 | with open(json_file, 'w') as f: 167 | json.dump(data, f) 168 | 169 | 170 | if args.copy_imgs_fraction > 0: # copy a random fraction of the images to the output directory 171 | indices = np.arange(len(uuids)) 172 | random_indices = indices[np.random.random(len(uuids)) < args.copy_imgs_fraction] 173 | src_paths = [img_paths[i] for i in random_indices] 174 | dst_paths = [f"{predicted_labels[i]:.3f}_{uuids[i]}.jpg" for i in random_indices] 175 | 176 | for src, dst in zip(src_paths, dst_paths): 177 | shutil.copy(src, os.path.join(output_dir, dst)) 178 | 179 | if n_predictions % 100 == 0: 180 | database.to_csv(label_file, index=False) 181 | 182 | database.to_csv(label_file, index=False) 183 | plot_label_distribution(database, args) 184 | 185 | print("Done!") 186 | print(f"{n_predictions} of {len(img_files)} img predicted. 
(the rest was skipped due to errors)") 187 | print(f"Average predicted label: {database['predicted_label'].mean():.3f}") 188 | print(f"Database saved at {label_file}") 189 | 190 | 191 | 192 | if __name__ == "__main__": 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('--root_dir', type=str, help='Root directory of the dataset') 195 | parser.add_argument('--model_file', type=str, help='Path to the model file (.pth)') 196 | parser.add_argument('--batch_size', type=int, default=12, help='Batch size for predicting') 197 | parser.add_argument('--copy_imgs_fraction', type=float, default=0.01, help='Fraction of images to copy to tmp_output directory with prepended prediction score') 198 | parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for the dataloader') 199 | args = parser.parse_args() 200 | 201 | mp.set_start_method('spawn') 202 | 203 | # recursively apply the model to all subdirectories: 204 | for root, dirs, files in os.walk(args.root_dir): 205 | jpg_files = [f for f in files if f.endswith('.jpg')] 206 | 207 | if len(jpg_files) > 0 and "_predicted_scores" not in root: 208 | args.root_dir = root 209 | print(f"\n\nPredicting labels for {root}...") 210 | predict_labels(args) -------------------------------------------------------------------------------- /_6_create_subset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from tqdm import tqdm 4 | import argparse 5 | import pandas as pd 6 | from PIL import Image 7 | 8 | def copy_data(args, output_suffix = '_subset'): 9 | ''' 10 | Copy all the files in the root_dir based on predicted label 11 | ''' 12 | 13 | # Get all the rows where the predicted label is above the threshold: 14 | database_path = os.path.join(os.path.dirname(args.input_dir), os.path.basename(args.input_dir) + ".csv") 15 | database = pd.read_csv(database_path) 16 | print(f"Loaded database with {len(database)} rows") 17 | 18 | # get the maximum actual label value: 19 | max_actual_label = database["label"].max() 20 | print(f"Max actual label: {max_actual_label}") 21 | 22 | # Define a function to apply the filtering criteria 23 | def filter_rows(row): 24 | scaling_f = 1 / max_actual_label 25 | final_label = row["label"] * scaling_f if pd.notnull(row["label"]) else row["predicted_label"] 26 | return args.min_score <= final_label <= args.max_score 27 | 28 | # Filter the DataFrame using the function 29 | database = database[database.apply(filter_rows, axis=1)] 30 | print(f"Found {len(database)} rows with {args.min_score} < final_label < {args.max_score}") 31 | 32 | output_suffix = f'_{args.min_score:.2f}_to_{args.max_score:.2f}' + output_suffix 33 | output_folder = os.path.join(args.input_dir + output_suffix) 34 | 35 | if args.test: 36 | print(f"##### Running script in TEST MODE: Not actually copying any files #####") 37 | else: 38 | os.makedirs(output_folder, exist_ok=True) 39 | 40 | # Loop over the uuids in the database and copy the corresponding files to the output folder: 41 | print(f"Copying files to {output_folder}...") 42 | counter = [0] * len(args.extensions) 43 | for uuid in tqdm(database["uuid"].values): 44 | # get the corresponding img path for this uuid: 45 | img_path = os.path.join(args.input_dir, uuid + ".jpg") 46 | try: 47 | with Image.open(img_path) as img: 48 | width, height = img.size 49 | aspect_ratio = width / height 50 | except Exception as e: 51 | print(f"Could not open {img_path}, {str(e)}") 52 | continue 53 | 54 | # check if the img 
is within the aspect ratio and pixel size range: 55 | if aspect_ratio < args.min_aspect_ratio or aspect_ratio > args.max_aspect_ratio or (width*height) <= args.min_n_pixels: 56 | continue 57 | 58 | for ext in args.extensions: 59 | filename = uuid + ext 60 | input_path = os.path.join(args.input_dir, filename) 61 | output_path = os.path.join(output_folder, filename) 62 | if not args.test and os.path.exists(input_path): 63 | shutil.copy(input_path, output_path) 64 | counter[args.extensions.index(ext)] += 1 65 | 66 | for ext, count in zip(args.extensions, counter): 67 | print(f"Copied {count} files with extension {ext} to {output_folder}") 68 | 69 | if not args.test: 70 | # count the total number of img files in the output folder: 71 | img_extensions = ('.jpg', '.jpeg', '.png') 72 | n_img_files = len([f for f in os.listdir(output_folder) if f.endswith(img_extensions)]) 73 | # append the total number of imgs to the output foldername: 74 | os.rename(output_folder, output_folder + f"_{n_img_files}_imgs") 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--input_dir', type=str, help='Input directory') 79 | parser.add_argument('--min_score', type=float, help='minimum score to copy') 80 | parser.add_argument('--max_score', type=float, default=1.0, help='Maximum score to copy') 81 | parser.add_argument('--extensions', nargs='+', default=['.jpg', '.txt', '.pt', '.pth'], help='Extensions to copy') 82 | parser.add_argument('--min_aspect_ratio', type=float, default=0.25, help='Minimum aspect ratio of imgs to copy') 83 | parser.add_argument('--max_aspect_ratio', type=float, default=4.00, help='Maximum aspect ratio of imgs to copy') 84 | parser.add_argument('--min_n_pixels', type=int, default=(512*512), help='Minimum number of total pixels of imgs to copy') 85 | parser.add_argument('--test', action='store_true', help='Test mode, wont actually copy anything') 86 | args = parser.parse_args() 87 | 88 | copy_data(args) -------------------------------------------------------------------------------- /investigate_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def print_structure(data, indent=0): 5 | """Recursively prints the structure of nested dictionaries containing tensors.""" 6 | prefix = " " * indent 7 | if isinstance(data, dict): 8 | for key, value in data.items(): 9 | print(f"{prefix}Key: {key}") 10 | if isinstance(value, torch.Tensor): 11 | print(f"{prefix} Shape: {value.shape}, Dtype: {value.dtype}") 12 | elif isinstance(value, dict): 13 | print_structure(value, indent + 1) 14 | else: 15 | print(f"{prefix} Type: {type(value)}") 16 | elif isinstance(data, torch.Tensor): 17 | # Handle case where the .pt file directly contains a tensor 18 | print(f"{prefix}Tensor Shape: {data.shape}, Dtype: {data.dtype}") 19 | else: 20 | print(f"{prefix}Type: {type(data)}") 21 | 22 | 23 | if __name__ == "__main__": 24 | # --- IMPORTANT: Replace this with the actual path to your .pt file --- 25 | file_path = "/data/xander/Projects/cog/xander_eden_stuff/tmp/all/2b1e66dff81c4b70b54731ae08edb1b1.pt" 26 | # -------------------------------------------------------------------- 27 | 28 | if not os.path.exists(file_path): 29 | print(f"Error: File not found at {file_path}") 30 | print("Please update the 'file_path' variable in the script.") 31 | exit() 32 | 33 | print(f"Loading data from: {file_path}") 34 | try: 35 | # Load the file, mapping to CPU to avoid GPU issues if saved on GPU 36 | 
loaded_data = torch.load(file_path, map_location='cpu') 37 | print("\n--- File Contents ---") 38 | print_structure(loaded_data) 39 | print("--------------------\n") 40 | 41 | except Exception as e: 42 | print(f"Error loading or processing file {file_path}: {e}") -------------------------------------------------------------------------------- /models/single_crop_regression_9.4k_imgs_80_epochs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aiXander/CLIP_assisted_data_labeling/b87baa76c9a9c46f005b765eac0afbfd43bc2bcd/models/single_crop_regression_9.4k_imgs_80_epochs.pth -------------------------------------------------------------------------------- /predict_simple.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from PIL import Image 4 | from utils.embedder import AestheticRegressor 5 | from tqdm import tqdm 6 | import argparse 7 | import torch 8 | 9 | def predict_images(img_paths, model_path, device, output_dir = None): 10 | 11 | # Load the scoring model (only do this once in a python session): 12 | aesthetic_regressor = AestheticRegressor(model_path, device = device) 13 | 14 | if output_dir is not None: 15 | os.makedirs(output_dir, exist_ok = True) 16 | 17 | print("\nPredicting aesthetic scores...") 18 | for image_path in tqdm(img_paths): 19 | score, embedding = aesthetic_regressor.predict_score(Image.open(image_path)) 20 | print(f"Score: {score:.3f} for {os.path.basename(image_path)}") 21 | 22 | if output_dir is not None: 23 | output_path = os.path.join(output_dir, f'{score:.3f}_' + os.path.basename(image_path)) 24 | shutil.copy(image_path, output_path) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | 30 | # IO args: 31 | parser.add_argument('--input_img_dir', type=str, help='Root directory of the (optionally multiple) datasets') 32 | parser.add_argument('--model_path', type=str, default='models/random_c_2024-12-10_11:34:22_4.8k_imgs_30_epochs_-1.0000_mse.pth', help='Path to the model file (.pth)') 33 | args = parser.parse_args() 34 | 35 | input_img_dir = args.input_img_dir 36 | 37 | #output_dir = None # dont copy the scored images 38 | output_dir = input_img_dir + "_aesthetic_scores" # copy the scored images 39 | 40 | # Get all the img_paths: 41 | img_extensions = [".jpg", ".png", ".jpeg", ".bmp", ".webp"] 42 | list_of_img_paths = [os.path.join(input_img_dir, img_name) for img_name in os.listdir(input_img_dir) if os.path.splitext(img_name)[1].lower() in img_extensions] 43 | print(f"Found {len(list_of_img_paths)} images in {input_img_dir}") 44 | 45 | device = "cuda" if torch.cuda.is_available() else "cpu" 46 | predict_images(list_of_img_paths, args.model_path, device, output_dir) 47 | 48 | 49 | -------------------------------------------------------------------------------- /tools/find_similar_imgs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | import torch 7 | from pathlib import Path 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | def get_filepaths(root_dir, extension = ['.pt']): 12 | filepaths = [] 13 | for root, dirs, files in os.walk(root_dir): 14 | for file in files: 15 | if file.endswith(tuple(extension)): 16 | filepaths.append(os.path.join(root, file)) 17 | return filepaths 18 | 19 | def 
create_context_embedding(args, context_dir): 20 | 21 | if args.clip_models_to_use[0] != "all": 22 | print(f"\n----> Using clip models: {args.clip_models_to_use}") 23 | 24 | # Load the feature vectors from disk (uuid.pt) 25 | print(f"\nLoading CLIP features from disk...") 26 | 27 | context_clip_features = [] 28 | context_pathnames = [] 29 | n_samples = 0 30 | skips = 0 31 | 32 | # Load all the clip_embeddings from the context dir: 33 | for embedding_path in tqdm(get_filepaths(context_dir)): 34 | 35 | try: 36 | full_feature_dict = torch.load(embedding_path) 37 | 38 | if args.clip_models_to_use[0] == "all": 39 | args.clip_models_to_use = list(full_feature_dict.keys()) 40 | print(f"\n----> Using all found clip models: {args.clip_models_to_use}") 41 | 42 | sample_features = [] 43 | 44 | for clip_model_name in args.clip_models_to_use: 45 | feature_dict = full_feature_dict[clip_model_name] 46 | clip_embedding = feature_dict[args.crop_name_to_use].flatten() 47 | sample_features.append(clip_embedding) 48 | 49 | context_clip_features.append(torch.cat(sample_features, dim=0)) 50 | context_pathnames.append(Path(embedding_path).name) 51 | n_samples += 1 52 | except Exception as e: # simply skip the sample if something goes wrong 53 | print(e) 54 | skips += 1 55 | continue 56 | 57 | print(f"Loaded {n_samples} samples from {context_dir}") 58 | if skips > 0: 59 | print(f"(skipped {skips} samples due to loading errors)..") 60 | 61 | context_clip_features = torch.stack(context_clip_features, dim=0).to(device).float() 62 | print("CLIP features for context loaded, shape: ", context_clip_features.shape) 63 | return torch.mean(context_clip_features, dim=0), context_pathnames 64 | 65 | 66 | 67 | class topN(): 68 | # Class to keep track of the top N most similar uuids and their corresponding distances 69 | def __init__(self, top_n): 70 | self.top_n = top_n 71 | self.best_img_paths = [] 72 | self.best_distances = [] 73 | 74 | def update(self, distance, img_path): 75 | 76 | if len(self.best_distances) < self.top_n: 77 | self.best_img_paths.append(img_path) 78 | self.best_distances.append(distance) 79 | else: 80 | # find the index of the img_path with the largest distance (using torch): 81 | idx = torch.tensor(self.best_distances).argmax().item() 82 | 83 | if distance < self.best_distances[idx]: 84 | self.best_img_paths[idx] = img_path 85 | self.best_distances[idx] = distance 86 | 87 | 88 | def compute_distance(context_clip_embedding, sample_clip_embedding, similarity_measure): 89 | if similarity_measure == "cosine": 90 | return (1-torch.nn.functional.cosine_similarity(context_clip_embedding, sample_clip_embedding, dim=-1))/2 91 | elif similarity_measure == "l2": 92 | return torch.nn.functional.pairwise_distance(context_clip_embedding, sample_clip_embedding, p=2, eps=1e-06) 93 | else: 94 | raise NotImplementedError(f"Similarity measure {similarity_measure} not implemented!") 95 | 96 | def find_similar_imgs(args, context_clip_embedding, context_pathnames): 97 | 98 | # Load the feature vectors from disk (uuid.pt) 99 | print(f"\nSearching {args.search_dir} for similar imgs. 
Saving results to {args.output_dir}..") 100 | 101 | n_samples = 0 102 | skips = 0 103 | 104 | topn = topN(args.top_n) 105 | 106 | # Load all the clip_embeddings from the context dir: 107 | for embedding_path in tqdm(get_filepaths(args.search_dir)): 108 | 109 | # Make sure there is a corresponding image file: 110 | img_path = embedding_path.replace(".pt", ".jpg") 111 | 112 | if not os.path.exists(img_path) or (Path(img_path).name in context_pathnames): 113 | continue 114 | 115 | try: 116 | full_feature_dict = torch.load(embedding_path) 117 | sample_features = [] 118 | for clip_model_name in args.clip_models_to_use: 119 | feature_dict = full_feature_dict[clip_model_name] 120 | clip_embedding = feature_dict[args.crop_name_to_use].flatten() 121 | sample_features.append(clip_embedding) 122 | 123 | sample_clip_embedding = torch.cat(sample_features, dim=0) 124 | 125 | d = compute_distance(context_clip_embedding, sample_clip_embedding, args.similarity_measure) 126 | topn.update(d, img_path) 127 | n_samples += 1 128 | except Exception as e: # simply skip the sample if something goes wrong 129 | print(e) 130 | skips += 1 131 | continue 132 | 133 | print(f"Searched through {n_samples} samples from {args.search_dir}") 134 | if skips > 0: 135 | print(f"(skipped {skips} samples due to loading errors)..") 136 | 137 | return topn 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description="Find similar images between the context and search directories using pre-computed CLIP embeddings") 142 | parser.add_argument("--context_dir", help="Directory to learn img context from") 143 | parser.add_argument("--search_dir", help="Directory to find similar imgs in") 144 | parser.add_argument("--output_dir", default=None, help="Directory to copy selected files to (default: search_dir_similar)") 145 | parser.add_argument('--clip_models_to_use', metavar='S', type=str, nargs='+', default=['all'], help='Which CLIP model embeddings to use, default: use all found') 146 | parser.add_argument("--crop_name_to_use", default="square_padded_crop", help="From which img crop to use the CLIP embedding") 147 | parser.add_argument("--similarity_measure", default="l2", help="Similarity measure to use in CLIP-space (cosine or l2)") 148 | parser.add_argument("--top_n", default=30, type=int, help="How many similar images to find") 149 | args = parser.parse_args() 150 | 151 | # check if the context dir contains .pt files, if not, it's a root dir, loop over its subdirs: 152 | if not any([f.endswith(".pt") for f in os.listdir(args.context_dir)]): 153 | context_dirs = [os.path.join(args.context_dir, d) for d in os.listdir(args.context_dir)] 154 | else: 155 | context_dirs = [args.context_dir] 156 | 157 | for context_dir in context_dirs: 158 | # Create CLIP-embedding of the context images: 159 | context_clip_embedding, context_pathnames = create_context_embedding(args, context_dir) 160 | 161 | # Create the output dir if it doesn't exist: 162 | args.output_dir = os.path.join(context_dir, "_similar") 163 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 164 | 165 | # Find the n most similar images in the search dir: 166 | topn = find_similar_imgs(args, context_clip_embedding, context_pathnames) 167 | 168 | # Move the best images to the output dir: 169 | for i, img_path in enumerate(topn.best_img_paths): 170 | distance = topn.best_distances[i] 171 | orig_stem = Path(img_path).stem 172 | out_path = os.path.join(args.output_dir, f"{distance:.3f}_{orig_stem}.jpg") 173 | shutil.copy(img_path, out_path) 
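174 | # Example invocation (a sketch; the directory paths are placeholders and the .pt CLIP embeddings are assumed to already exist next to the .jpg files): 175 | #   python tools/find_similar_imgs.py --context_dir /path/to/reference_imgs --search_dir /path/to/large_dataset --similarity_measure cosine --top_n 50 176 | # The top_n matches are copied into a "_similar" subfolder of each context dir, with the computed distance prepended to the filename.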
-------------------------------------------------------------------------------- /tools/fix_img_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from PIL import Image 4 | 5 | def process_images(src_folder, tmp_folder): 6 | if not os.path.exists(tmp_folder): 7 | os.makedirs(tmp_folder) 8 | 9 | for file in os.listdir(src_folder): 10 | if file.lower().endswith('.jpg'): 11 | file_path = os.path.join(src_folder, file) 12 | 13 | try: 14 | with Image.open(file_path) as img: 15 | # Perform any image processing here if needed 16 | print(f"Successfully opened {file}") 17 | except Exception as e: 18 | print(f"Error opening {file}: {e}") 19 | dest_path = os.path.join(tmp_folder, file) 20 | shutil.move(file_path, dest_path) 21 | print(f"Moved {file} to the tmp folder") 22 | 23 | if __name__ == "__main__": 24 | target_folder = "/home/xander/Projects/cog/eden-sd-pipelines/eden/xander/assets/gordon/combo" 25 | tmp_directory = "/home/xander/Projects/cog/eden-sd-pipelines/eden/xander/assets/gordon/combo_errored" 26 | 27 | process_images(target_folder, tmp_directory) 28 | -------------------------------------------------------------------------------- /tools/move_subset_of_files.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import random 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | def crawl_directory(root_dir, file_extensions): 9 | files = {} 10 | for dirpath, dirnames, filenames in os.walk(root_dir): 11 | for filename in filenames: 12 | if any(filename.endswith(ext) for ext in file_extensions): 13 | basename = os.path.splitext(filename)[0] 14 | if basename not in files: 15 | files[basename] = [] 16 | files[basename].append(os.path.join(dirpath, filename)) 17 | return files 18 | 19 | def copy_files(files, out_dir, fraction_f): 20 | n_copied_samples = 0 21 | for basename, paths in tqdm(files.items()): 22 | if random.random() < fraction_f: 23 | n_copied_samples += 1 24 | for path in paths: 25 | dest_path = os.path.join(out_dir, os.path.relpath(path, root_dir)) 26 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 27 | shutil.copy2(path, dest_path) 28 | 29 | print(f"Copied {n_copied_samples} samples to {out_dir}") 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description="Copy a fraction of files with specified extensions to out_dir") 33 | parser.add_argument("--root_dir", help="Directory to crawl for files") 34 | parser.add_argument("--out_dir", default=None, help="Directory to copy selected files to (default: same as root_dir)") 35 | parser.add_argument("--fraction_f", type=float, default=0.01, help="Fraction of files to copy (default: 0.001)") 36 | parser.add_argument("--file_extensions", nargs="+", default=['.jpg'], help="List of file extensions to consider (default: .jpg)") 37 | args = parser.parse_args() 38 | 39 | # Removing any possible trailing / from root_dir: 40 | args.root_dir = str(Path(args.root_dir).resolve()) 41 | 42 | if args.out_dir is None: 43 | args.out_dir = args.root_dir + f"_{args.fraction_f:.3f}_subset" 44 | 45 | root_dir = args.root_dir 46 | out_dir = args.out_dir 47 | fraction_f = args.fraction_f 48 | file_extensions = args.file_extensions 49 | 50 | files = crawl_directory(root_dir, file_extensions) 51 | copy_files(files, out_dir, fraction_f) -------------------------------------------------------------------------------- /utils/embedder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import torchvision.transforms as transforms 4 | from torchvision import models 5 | from torch.utils.data import Dataset 6 | import open_clip 7 | import os, sys 8 | import random 9 | from PIL import Image 10 | import numpy as np 11 | 12 | # Hardcoded hack, TODO: clean this up: 13 | pe_path = "../perception_models" 14 | sys.path.append(pe_path) 15 | import core.vision_encoder.pe as pe 16 | import core.vision_encoder.transforms as pe_transforms 17 | 18 | from .image_features import ImageFeaturizer 19 | 20 | _DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | 22 | def extract_vgg_features(image, model_name='vgg', layer_index=10): 23 | # Load pre-trained model 24 | if model_name == 'vgg': 25 | model = models.vgg16(pretrained=True).features 26 | elif model_name == 'alexnet': 27 | model = models.alexnet(pretrained=True).features 28 | else: 29 | raise ValueError('Invalid model name. Choose "vgg" or "alexnet".') 30 | 31 | # Set model to evaluation mode 32 | model.eval() 33 | 34 | # Extract features up to the specified layer 35 | model = torch.nn.Sequential(*list(model.children())[:layer_index+1]) 36 | 37 | # Define image transformation 38 | transform = transforms.Compose([ 39 | transforms.Resize(256), 40 | transforms.CenterCrop(224), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 43 | ]) 44 | 45 | image = transform(image).unsqueeze(0) 46 | 47 | # Move image to GPU if available 48 | image = image.to(_DEVICE) 49 | model = model.to(_DEVICE) 50 | 51 | # Extract features 52 | with torch.no_grad(): 53 | features = model(image) 54 | 55 | return features 56 | 57 | 58 | class CLIP_Encoder: 59 | def __init__(self, model_name, model_path=None, device=None): 60 | self.device = device if device else _DEVICE 61 | self.precision = 'fp16' if self.device == 'cuda' else 'fp32' 62 | self.model_name = model_name 63 | self.model_architecture, self.pretrained_dataset = self.model_name.split('/', 2) 64 | 65 | print(f"Loading CLIP model {self.model_name}...") 66 | self.model, _, self.preprocess = open_clip.create_model_and_transforms( 67 | self.model_architecture, 68 | pretrained=self.pretrained_dataset, 69 | precision=self.precision, 70 | device=self.device, 71 | jit=False, 72 | cache_dir=model_path 73 | ) 74 | self.model = self.model.to(self.device).eval() 75 | # Extract resolution from preprocess transforms 76 | self.img_resolution = 224 # Default 77 | for t in self.preprocess.transforms: 78 | if isinstance(t, transforms.Resize): 79 | # open_clip uses int or tuple for size. If int, it's square. 
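Some Resize transforms store a (height, width) tuple rather than an int, so the fallback below takes the first entry and assumes a square input.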
80 | size = t.size 81 | if isinstance(size, int): 82 | self.img_resolution = size 83 | elif isinstance(size, (list, tuple)) and len(size) >= 1: 84 | self.img_resolution = size[0] # Assume square if tuple/list 85 | break 86 | 87 | 88 | print(f"CLIP model {self.model_name} with img_resolution {self.img_resolution} loaded on {self.device}!") 89 | 90 | def get_preprocess_transform(self): 91 | # Return the full preprocessing pipeline from open_clip 92 | return self.preprocess 93 | 94 | @torch.no_grad() 95 | def encode_image(self, preprocessed_images: torch.Tensor) -> torch.Tensor: 96 | if self.precision == 'fp16': 97 | preprocessed_images = preprocessed_images.half() 98 | image_features = self.model.encode_image(preprocessed_images) 99 | image_features /= image_features.norm(dim=-1, keepdim=True) 100 | return image_features 101 | 102 | 103 | class PE_Encoder: 104 | def __init__(self, model_name, device=None): 105 | self.device = device if device else _DEVICE 106 | self.model_name = model_name 107 | 108 | print(f"Loading PE model {self.model_name}...") 109 | self.model = pe.CLIP.from_config(self.model_name, pretrained=True) 110 | self.model = self.model.to(self.device).eval() 111 | self.img_resolution = self.model.image_size 112 | self.context_length = self.model.context_length # Needed for tokenizer later if text is used 113 | 114 | # Get the appropriate image transform 115 | # self.preprocess = pe_transforms.get_image_transform(self.img_resolution) # Removed due to potential lambda/pickle issues 116 | # Define preprocessing directly to ensure pickle compatibility 117 | self.preprocess = transforms.Compose([ 118 | transforms.Resize(self.img_resolution, interpolation=transforms.InterpolationMode.BICUBIC), 119 | transforms.CenterCrop(self.img_resolution), 120 | transforms.ToTensor(), 121 | transforms.Normalize( 122 | mean=(0.48145466, 0.4578275, 0.40821073), # Assuming standard CLIP/PE mean 123 | std=(0.26862954, 0.26130258, 0.27577711), # Assuming standard CLIP/PE std 124 | ), 125 | ]) 126 | # self.tokenizer = pe_transforms.get_text_tokenizer(self.context_length) # If text needed 127 | 128 | print(f"PE model {self.model_name} with img_resolution {self.img_resolution} loaded on {self.device}!") 129 | 130 | def get_preprocess_transform(self): 131 | # Return the specific PE image transform 132 | return self.preprocess 133 | 134 | @torch.no_grad() 135 | @torch.autocast("cuda") # PE example uses autocast 136 | def encode_image(self, preprocessed_images: torch.Tensor) -> torch.Tensor: 137 | # PE model forward returns tuple (image_features, text_features, logit_scale) 138 | # We only need image_features for this task. PE might need dummy text. 139 | # Let's check the signature or assume None works for text if only image features needed. 140 | # From test.py, it seems PE model expects both image and text. 141 | # We need a way to get only image features. Let's assume model.encode_image exists, like CLIP. 142 | # If not, we'll need to adapt. Let's check PE source or documentation. 
143 | # Assuming `encode_image` exists and works like open_clip: 144 | image_features = self.model.encode_image(preprocessed_images) # Placeholder assumption 145 | # If encode_image doesn't exist, we'd use: 146 | # dummy_text = self.tokenizer([""]).to(self.device) # Create dummy text input 147 | # image_features, _, _ = self.model(preprocessed_images, dummy_text) 148 | 149 | image_features /= image_features.norm(dim=-1, keepdim=True) 150 | return image_features 151 | 152 | 153 | class CustomImageDataset(Dataset): 154 | # Modified __init__ to accept preprocess_transform instead of img_resolution/device 155 | def __init__(self, image_paths, crop_names, preprocess_transform): 156 | self.image_paths = image_paths 157 | self.crop_names = crop_names 158 | self.preprocess_transform = preprocess_transform # Store the transform 159 | self.img_featurizer = ImageFeaturizer() 160 | 161 | def __len__(self): 162 | return len(self.image_paths) 163 | 164 | def __getitem__(self, idx): 165 | try: 166 | img_path = self.image_paths[idx] 167 | pil_img = Image.open(img_path).convert('RGB') 168 | # extract_crops now returns raw PIL crops or tensors before final preprocessing 169 | raw_crops, crop_names_list = self.extract_crops(pil_img) 170 | image_features = self.img_featurizer.process(np.array(pil_img)) 171 | 172 | # Apply the specific preprocessing transform to each crop 173 | processed_crops = torch.stack([self.preprocess_transform(crop) for crop in raw_crops]) 174 | 175 | return processed_crops, crop_names_list, img_path, image_features 176 | except Exception as e: 177 | print(f"Error loading or processing image {img_path}: {e}") 178 | # Return data from a random valid image instead 179 | random_idx = random.randint(0, len(self.image_paths)-1) 180 | print(f"Substituting with image index {random_idx}") 181 | return self.__getitem__(random_idx) 182 | 183 | # Modified extract_crops to return PIL images or basic tensors, preprocessing is done in __getitem__ 184 | def extract_crops(self, pil_img: Image) -> (list, list): 185 | img_tensor = transforms.ToTensor()(pil_img) # Keep as basic tensor initially 186 | c, h, w = img_tensor.shape 187 | 188 | raw_crops, crop_names_list = [], [] 189 | 190 | # Convert tensor back to PIL for potential transforms that expect PIL 191 | # Or adjust cropping logic if transforms handle tensors directly 192 | # Assuming preprocess_transform takes PIL: 193 | pil_img_for_crop = transforms.ToPILImage()(img_tensor) 194 | 195 | 196 | if 'centre_crop' in self.crop_names: 197 | crop_size = min(pil_img.width, pil_img.height) 198 | # Use torchvision transforms for cropping PIL images 199 | centre_crop_transform = transforms.CenterCrop(crop_size) 200 | centre_crop_pil = centre_crop_transform(pil_img_for_crop) 201 | raw_crops.append(centre_crop_pil) 202 | crop_names_list.append("centre_crop") 203 | 204 | if 'square_padded_crop' in self.crop_names: 205 | crop_size = max(pil_img.width, pil_img.height) 206 | # Create square padded PIL image 207 | square_padded_pil = Image.new("RGB", (crop_size, crop_size), (0, 0, 0)) 208 | start_h = (crop_size - pil_img.height) // 2 209 | start_w = (crop_size - pil_img.width) // 2 210 | square_padded_pil.paste(pil_img_for_crop, (start_w, start_h)) 211 | raw_crops.append(square_padded_pil) 212 | crop_names_list.append("square_padded_crop") 213 | 214 | 215 | if any('subcrop1' in name for name in self.crop_names) or any('subcrop2' in name for name in self.crop_names): 216 | subcrop_area_fractions = [0.15, 0.1] 217 | subcrop_w1 = int((pil_img.width * pil_img.height * 
subcrop_area_fractions[0]) ** 0.5) 218 | subcrop_h1 = subcrop_w1 219 | subcrop_w2 = int((pil_img.width * pil_img.height * subcrop_area_fractions[1]) ** 0.5) 220 | subcrop_h2 = subcrop_w2 221 | 222 | if pil_img.width >= pil_img.height: # wide / square img 223 | centers = [(pil_img.width // 4, pil_img.height // 2), (pil_img.width // 4 * 3, pil_img.height // 2)] 224 | else: # tall img 225 | centers = [(pil_img.width // 2, pil_img.height // 4), (pil_img.width // 2, pil_img.height // 4 * 3)] 226 | 227 | sizes = [(subcrop_w1, subcrop_h1), (subcrop_w2, subcrop_h2)] 228 | names = ['subcrop1', 'subcrop2'] 229 | 230 | for i, (center_w, center_h) in enumerate(centers): 231 | if names[i] in self.crop_names: 232 | width, height = sizes[i] 233 | left = max(0, center_w - width // 2) 234 | top = max(0, center_h - height // 2) 235 | right = min(pil_img.width, left + width) 236 | bottom = min(pil_img.height, top + height) 237 | 238 | # Adjust size if crop went out of bounds to maintain aspect ratio (or simply crop) 239 | # Using PIL's crop which handles bounds: box is (left, upper, right, lower) 240 | subcrop_pil = pil_img_for_crop.crop((left, top, right, bottom)) 241 | 242 | # Ensure the cropped area isn't empty due to rounding/bounds 243 | if subcrop_pil.width > 0 and subcrop_pil.height > 0: 244 | raw_crops.append(subcrop_pil) 245 | crop_names_list.append(names[i]) 246 | else: 247 | print(f"Warning: {names[i]} for image {idx} resulted in zero size.") 248 | 249 | 250 | # Return list of PIL images and their names 251 | return raw_crops, crop_names_list 252 | 253 | 254 | import time 255 | class Timer(): 256 | 'convenience class to time code' 257 | def __init__(self, name, start = False): 258 | self.name = name 259 | self.total_time_running = 0.0 260 | if start: 261 | self.start() 262 | 263 | def pause(self): 264 | self.total_time_running += time.time() - self.last_start 265 | 266 | def start(self): 267 | self.last_start = time.time() 268 | 269 | def status(self): 270 | print(f'{self.name} accumulated {self.total_time_running:.3f} seconds of runtime') 271 | 272 | def exit(self, *args): 273 | self.total_time_running += time.time() - self.last_start 274 | print(f'{self.name} took {self.total_time_running:.3f} seconds') 275 | 276 | 277 | class AestheticRegressor: 278 | """ 279 | Aesthetic Regressor to predict the aesthetic score of images. 
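Example usage (mirrors predict_simple.py): regressor = AestheticRegressor(model_path, device=device); score, embedding = regressor.predict_score(Image.open(image_path))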
280 | """ 281 | def __init__(self, model_path, device="cpu"): 282 | self.device = device 283 | self.load_model(model_path) 284 | 285 | # Load associated CLIP models 286 | self.clip_models = [CLIP_Encoder(name, device=self.device) for name in self.model.clip_models] 287 | 288 | @torch.no_grad() 289 | def load_model(self, model_path, verbose=1): 290 | self.model = torch.load(model_path, map_location = self.device).to(self.device).eval() 291 | if verbose: 292 | print("Loaded regression model") 293 | print(f"Aesthetic Regressor was trained on embeddings from CLIP models:") 294 | print(self.model.clip_models) 295 | print(f"Aesthetic Regressor used crops:") 296 | print(self.model.crop_names) 297 | 298 | @torch.no_grad() 299 | def predict_score(self, pil_img): 300 | all_img_features = [] 301 | 302 | for clip_model in self.clip_models: 303 | img_dataset = CustomImageDataset([pil_img], clip_model.crop_names, clip_model.get_preprocess_transform()) 304 | crops, _ = img_dataset.extract_crops(pil_img) 305 | features = clip_model.encode_image(crops).unsqueeze(0) # Add batch dimension 306 | all_img_features.append(features) 307 | 308 | features = torch.stack(all_img_features).flatten().unsqueeze(0) # Add batch dimension 309 | score = self.model(features.to(self.device).float()).item() 310 | 311 | return score, features 312 | -------------------------------------------------------------------------------- /utils/image_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.fftpack import dct 3 | from scipy.stats import entropy 4 | import cv2 5 | import os 6 | 7 | def colorfulness(numpy_img): 8 | # Split the image into its color channels 9 | (B, G, R) = cv2.split(numpy_img.astype("float")) 10 | 11 | # Compute rg = R - G 12 | rg = np.absolute(R - G) 13 | 14 | # Compute yb = 0.5 * (R + G) - B 15 | yb = np.absolute(0.5 * (R + G) - B) 16 | 17 | # Compute the mean and standard deviation of both `rg` and `yb` 18 | (rg_mean, rg_std) = (np.mean(rg), np.std(rg)) 19 | (yb_mean, yb_std) = (np.mean(yb), np.std(yb)) 20 | 21 | # Combine the mean and standard deviations 22 | std_root = np.sqrt((rg_std ** 2) + (yb_std ** 2)) 23 | mean_root = np.sqrt((rg_mean ** 2) + (yb_mean ** 2)) 24 | 25 | # Compute the colorfulness metric 26 | colorfulness = std_root + (0.3 * mean_root) 27 | 28 | return colorfulness / 100 29 | 30 | def image_entropy(image, _nbins = 256): 31 | """ 32 | approximates information content of an image 33 | low entropy are usually images with a lot of flat colors 34 | high entropy is usually images with white/gray noise, high frequency edges etc. 
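The raw histogram entropy is divided by log2(_nbins), so the returned value is normalized to [0, 1].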
35 | """ 36 | histogram = cv2.calcHist([image], [0], None, [_nbins], [0, _nbins]) 37 | histogram /= histogram.sum() 38 | entropy = -np.sum(histogram * np.log2(histogram + np.finfo(float).eps)) 39 | entropy = entropy / np.log2(_nbins) 40 | return entropy 41 | 42 | def laplacian_variance(image, normalization_scale_factor=1e-4): 43 | """ 44 | similar to image entropy, but more sensitive to edges, can detect blurriness 45 | """ 46 | laplacian = cv2.Laplacian(image, cv2.CV_64F) 47 | variance = np.var(laplacian) 48 | normalized_variance = np.tanh(variance * normalization_scale_factor) 49 | return normalized_variance 50 | 51 | class ImageFeaturizer(): 52 | def __init__(self, max_n_pixels = 768*768): 53 | self.max_n_pixels = max_n_pixels 54 | 55 | def process(self, rgb_image, verbose = False): 56 | 57 | # resize the image to max_n_pixels: 58 | w,h = rgb_image.shape[:2] 59 | new_w, new_h = int(np.sqrt(self.max_n_pixels * w / h)), int(np.sqrt(self.max_n_pixels * h / w)) 60 | rgb_image = cv2.resize(rgb_image, (new_w, new_h), interpolation = cv2.INTER_AREA) 61 | gray_image = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2GRAY) 62 | hsv_img = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2HSV) 63 | 64 | feature_dict = { 65 | 'img_stat_width': rgb_image.shape[1] / 768, 66 | 'img_stat_height': rgb_image.shape[0] / 768, 67 | 'img_stat_aspect_ratio': rgb_image.shape[1] / rgb_image.shape[0], 68 | 'img_stat_mean_color': np.mean(rgb_image) / 255, 69 | 'img_stat_std_color': np.std(rgb_image) / 255, 70 | 'img_stat_mean_red': np.mean(rgb_image[:,:,0]) / 255, 71 | 'img_stat_mean_green': np.mean(rgb_image[:,:,1]) / 255, 72 | 'img_stat_mean_blue': np.mean(rgb_image[:,:,2]) / 255, 73 | 'img_stat_std_red': np.std(rgb_image[:,:,0]) / 255, 74 | 'img_stat_std_green': np.std(rgb_image[:,:,1]) / 255, 75 | 'img_stat_std_blue': np.std(rgb_image[:,:,2]) / 255, 76 | 'img_stat_mean_gray': np.mean(gray_image) / 255, 77 | 'img_stat_std_gray': np.std(gray_image) / 255, 78 | 'img_stat_mean_hue': np.mean(hsv_img[:,:,0]) / 255, 79 | 'img_stat_mean_sat': np.mean(hsv_img[:,:,1]) / 255, 80 | 'img_stat_mean_val': np.mean(hsv_img[:,:,2]) / 255, 81 | 'img_stat_std_hue': np.std(hsv_img[:,:,0]) / 255, 82 | 'img_stat_std_sat': np.std(hsv_img[:,:,1]) / 255, 83 | 'img_stat_std_val': np.std(hsv_img[:,:,2]) / 255, 84 | 'img_stat_colorfulness': colorfulness(rgb_image), 85 | 'img_stat_image_entropy': image_entropy(gray_image), 86 | 'img_stat_laplacian_variance': laplacian_variance(gray_image) 87 | } 88 | 89 | if verbose: 90 | print("-----------------------------") 91 | for key, value in feature_dict.items(): 92 | print(f'{key}: {value:.4f}') 93 | 94 | return feature_dict 95 | 96 | if __name__ == '__main__': 97 | folder = "/home/rednax/SSD2TB/Fast_Datasets/SD/Labeling/datasets/todo" 98 | output_folder = "/home/rednax/SSD2TB/Fast_Datasets/SD/Labeling/datasets/todo_color" 99 | extensions = [".jpg", ".png"] 100 | 101 | os.makedirs(output_folder, exist_ok = True) 102 | 103 | # get all img_paths in folder: 104 | image_paths = [] 105 | for root, dirs, files in os.walk(folder): 106 | for file in files: 107 | if os.path.splitext(file)[1] in extensions: 108 | image_paths.append(os.path.join(root, file)) 109 | 110 | for image_path in image_paths: 111 | image = cv2.imread(image_path) 112 | featurizer = ImageFeaturizer() 113 | features = featurizer.process(image, verbose = True) -------------------------------------------------------------------------------- /utils/merge_datasets.py: -------------------------------------------------------------------------------- 1 | import os 
2 | import pandas as pd 3 | import shutil 4 | 5 | """ 6 | 7 | cd /home/rednax/SSD2TB/Fast_Datasets/SD/Labeling 8 | python3 merge_datasets.py 9 | 10 | merge all datasets in the data_dir (all subfolders and their corresponding .csv files) into two datasets: labeled and unlabeled 11 | 12 | """ 13 | 14 | 15 | # Constants 16 | data_dir = "/home/rednax/SSD2TB/Fast_Datasets/SD/Labeling/test" 17 | output_dir = "/home/rednax/SSD2TB/Fast_Datasets/SD/Labeling/merged" 18 | 19 | labeled_dir = os.path.join(output_dir, "labeled") 20 | unlabeled_dir = os.path.join(output_dir, "unlabeled") 21 | 22 | # Create the output directories if they don't exist 23 | os.makedirs(labeled_dir, exist_ok=True) 24 | os.makedirs(unlabeled_dir, exist_ok=True) 25 | 26 | # List of dataframes, one per csv file 27 | dfs = [] 28 | 29 | # Iterate over the subdirectories 30 | for subdir in os.listdir(data_dir): 31 | subdir_path = os.path.join(data_dir, subdir) 32 | 33 | if os.path.isdir(subdir_path): 34 | # Load the csv file associated with the subdirectory 35 | csv_path = os.path.join(data_dir, f"{subdir}.csv") 36 | 37 | if os.path.exists(csv_path): 38 | df = pd.read_csv(csv_path) 39 | 40 | # Add a column with the name of the subdirectory: 41 | df['source_datadir'] = subdir 42 | dfs.append(df) 43 | 44 | # Concatenate all dataframes 45 | combined_df = pd.concat(dfs) 46 | 47 | # Split into labeled and unlabeled 48 | labeled_df = combined_df[combined_df['label'].notna()] 49 | unlabeled_df = combined_df[combined_df['label'].isna()] 50 | 51 | # Save as CSV files 52 | labeled_df.to_csv(os.path.join(output_dir, "labeled.csv"), index=False) 53 | unlabeled_df.to_csv(os.path.join(output_dir, "unlabeled.csv"), index=False) 54 | 55 | # Function to move files based on uuids 56 | def move_files(df, source_dir, destination_dir, extensions_to_move = ['.jpg', '.json', '.txt', '.pt', '.pth']): 57 | moved = 0 58 | uuids = df['uuid'].values 59 | source_dirs = df['source_datadir'].values 60 | 61 | for i, uuid in enumerate(uuids): 62 | for extension in extensions_to_move: 63 | # Here we assume that the file extension is .jpg, change it as per your dataset 64 | 65 | source_file = os.path.join(source_dir, source_dirs[i], f"{uuid}{extension}") 66 | 67 | if os.path.exists(source_file): 68 | destination_file = os.path.join(destination_dir, f"{uuid}{extension}") 69 | shutil.move(source_file, destination_file) 70 | moved += 1 71 | 72 | print(f"Moved {moved} files from {source_dir} to {destination_dir}!") 73 | 74 | # Move labeled and unlabeled files to their respective directories 75 | move_files(labeled_df, data_dir, labeled_dir) 76 | move_files(unlabeled_df, data_dir, unlabeled_dir) 77 | -------------------------------------------------------------------------------- /utils/nn_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | 6 | class SimpleFC(nn.Module): 7 | def __init__(self, input_size, hidden_sizes, output_size, clip_models, 8 | crop_names = ['centre_crop', 'square_padded_crop', 'subcrop1', 'subcrop2'], 9 | use_img_stat_features = False, 10 | dropout_prob=0.0, 11 | data_min = None, data_max = None, 12 | verbose = 0): 13 | 14 | super(SimpleFC, self).__init__() 15 | self.clip_models = clip_models # Names of the clip models that were used to encode the images 16 | self.crop_names = crop_names 17 | self.use_img_stat_features = use_img_stat_features 18 | layer_sizes = [input_size] + hidden_sizes 
+ [output_size] 19 | self.data_min, self.data_max = data_min, data_max 20 | 21 | # Define the network: 22 | layers = [] 23 | for i in range(len(layer_sizes) - 1): 24 | layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1])) 25 | # add ReLU and Dropout after each layer except the last one 26 | if i < len(layer_sizes) - 2: 27 | layers.append(nn.LeakyReLU()) 28 | layers.append(nn.Dropout(p=dropout_prob)) 29 | 30 | # Add sigmoid at the end: (assumes that all labels are normalized in the range [0,1]) 31 | layers.append(nn.Sigmoid()) 32 | 33 | self.layers = nn.ModuleList(layers) 34 | 35 | if verbose > 0: # Print the final network layout: 36 | print(self) 37 | 38 | def forward(self, x): 39 | for layer in self.layers: 40 | x = layer(x) 41 | return x 42 | 43 | 44 | 45 | class SimpleconvFC(nn.Module): 46 | def __init__(self, input_size, hidden_sizes, output_size, 47 | crop_names = ['centre_crop', 'square_padded_crop', 'subcrop1', 'subcrop2'], 48 | use_img_stat_features = False, 49 | dropout_prob=0.0, 50 | data_min = None, data_max = None, 51 | verbose = 0, 52 | conv_out_channels = 64, # added new parameter for conv output channels 53 | kernel_size = 5): # added new parameter for kernel size 54 | 55 | super(SimpleconvFC, self).__init__() 56 | self.crop_names = crop_names 57 | self.use_img_stat_features = use_img_stat_features 58 | self.data_min, self.data_max = data_min, data_max 59 | 60 | # Define the 1D Convolutional layer: 61 | input_size = 768*2 62 | self.conv1 = nn.Conv1d(input_size, conv_out_channels, kernel_size) 63 | 64 | # Adjust input size for FC layers based on conv output: 65 | layer_sizes = [4672] + hidden_sizes + [output_size] 66 | 67 | # Define the network: 68 | layers = [] 69 | for i in range(len(layer_sizes) - 1): 70 | layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1])) 71 | # add ReLU and Dropout after each layer except the last one 72 | if i < len(layer_sizes) - 2: 73 | layers.append(nn.ReLU()) 74 | layers.append(nn.Dropout(p=dropout_prob)) 75 | 76 | # Add sigmoid at the end: (assumes that all labels are normalized in the range [0,1]) 77 | layers.append(nn.Sigmoid()) 78 | 79 | self.layers = nn.ModuleList(layers) 80 | 81 | if verbose > 0: # Print the final network layout: 82 | print(self) 83 | 84 | def forward(self, x, verbose = 0): 85 | #x = x.view(x.size(0), 2, 77, 768) 86 | 87 | # Reshape to make 77 the last (temporal dimension), and concatenate the 2 channels (c and uc): into 2*768 features: 88 | x = x.permute(0, 1, 3, 2).reshape(x.size(0), 2*768, 77) 89 | 90 | if verbose: 91 | print("Pre conv:") 92 | print(x.shape) 93 | 94 | x = self.conv1(x) 95 | 96 | if verbose: 97 | print("Post conv, pre flatten:") 98 | print(x.shape) 99 | 100 | x = x.view(x.size(0), -1) # flatten the tensor 101 | 102 | if verbose: 103 | print("Post conv, post flatten:") 104 | print(x.shape) 105 | 106 | for layer in self.layers: 107 | x = layer(x) 108 | return x -------------------------------------------------------------------------------- /utils/train_latent_regressor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import argparse 5 | import pickle 6 | from tqdm import tqdm 7 | import torch 8 | from torch import nn 9 | from torch.utils.data import DataLoader, Dataset, random_split 10 | from torch.optim import Adam 11 | import matplotlib.pyplot as plt 12 | from nn_model import device, SimpleFC, SimpleconvFC 13 | from sklearn.metrics import r2_score 14 | 15 | """ 16 | This is an unfinished 
exerpiment: 17 | given a dataset of StableDiffusion prompt_embeds / aesthetic scores, try to learn a mapping from prompt_embeds to scores 18 | The idea here is to use this regressor to do prompt-augmentation in latent space. 19 | 20 | cd /home/rednax/SSD2TB/Xander_Tools/CLIP_assisted_data_labeling/utils 21 | python train_latent_regressor.py --train_data_dir /home/rednax/SSD2TB/Github_repos/cog/eden-sd-pipelines/eden/xander/images/random_c_uc_fin_dataset --train_data_names no_lora eden eden2 --model_name c_uc_regressor --test_fraction 0.4 --dont_save 22 | 23 | 24 | """ 25 | 26 | def train(args): 27 | 28 | torch.manual_seed(args.random_seed) 29 | np.random.seed(args.random_seed) 30 | 31 | features = [] 32 | labels = [] 33 | 34 | # Load all the labeled training data from disk: 35 | for train_data_name in args.train_data_names: 36 | n_samples = 0 37 | skips = 0 38 | 39 | # Load the labels and uuid's from labels.csv 40 | data = pd.read_csv(os.path.join(args.train_data_dir, train_data_name + '.csv')) 41 | # Drop all the rows where "label" is NaN: 42 | #data = data.dropna(subset=["label"]) 43 | # randomly shuffle the data: 44 | data = data.sample(frac=1).reset_index(drop=True) 45 | 46 | # Load the prompt_embeds from disk (uuid.pth) 47 | print(f"\nLoading {train_data_name} features from disk...") 48 | for index, row in tqdm(data.iterrows()): 49 | try: 50 | uuid = row["uuid"] 51 | # load row["label"] is it exists, otherwise load row["predicted_label"] 52 | label = row["label"] if not np.isnan(row["label"]) else row["predicted_label"]*0.5 53 | prompt_embeds = torch.load(f"{args.train_data_dir}/{train_data_name}/{uuid}.pth") 54 | features.append(prompt_embeds.flatten()) 55 | labels.append(label) 56 | n_samples += 1 57 | except: # simply skip the sample if something goes wrong 58 | skips += 1 59 | continue 60 | 61 | print(f"Loaded {n_samples} samples from {train_data_name}!") 62 | if skips > 0: 63 | print(f"(skipped {skips} samples due to loading errors)..") 64 | 65 | features = torch.stack(features, dim=0).to(device).float() 66 | labels = torch.tensor(labels).to(device).float() 67 | 68 | # Map the labels to 0-1: 69 | print("Normalizing labels to [0,1]...") 70 | labels_min, labels_max = labels.min(), labels.max() 71 | labels = (labels - labels_min) / (labels_max - labels_min) 72 | 73 | print("\n--- All data loaded ---") 74 | print("Features shape:", features.shape) 75 | print("Labels shape:", labels.shape) 76 | 77 | # 2. Create train and test dataloaders 78 | class RegressionDataset(Dataset): 79 | def __init__(self, features, labels): 80 | self.features = features 81 | self.labels = labels 82 | 83 | def __len__(self): 84 | return len(self.features) 85 | 86 | def __getitem__(self, idx): 87 | return self.features[idx], self.labels[idx] 88 | 89 | dataset = RegressionDataset(features, labels) 90 | train_size = int((1-args.test_fraction) * len(dataset)) 91 | test_size = len(dataset) - train_size 92 | 93 | print(f"Training on {train_size} samples, testing on {test_size} samples.") 94 | 95 | train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) 96 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 97 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 98 | 99 | # 3. 
Create the network 100 | model = SimpleFC(features.shape[1], args.hidden_sizes, 1, 101 | dropout_prob=args.dropout_prob, 102 | verbose = args.print_network_layout, 103 | data_min = labels_min, data_max = labels_max) 104 | model.train() 105 | model.to(device) 106 | 107 | # 4. Train the network for n epochs using Adam optimizer and standard regression loss 108 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 109 | criterion = nn.MSELoss() 110 | losses = [[], []] # train, test losses 111 | 112 | def get_test_loss(model, test_loader, epoch, plot_correlation = 1): 113 | if len(test_loader) == 0: 114 | return -1.0, -1 115 | model.eval() 116 | test_loss, dummy_test_loss = 0.0, 0.0 117 | test_preds, test_labels = [], [] 118 | with torch.no_grad(): 119 | for features, labels in test_loader: 120 | outputs = model(features) 121 | loss = criterion(outputs.squeeze(), labels) 122 | test_loss += loss.item() 123 | 124 | dummy_outputs = torch.ones_like(outputs) * labels.mean() 125 | dummy_loss = criterion(dummy_outputs.squeeze(), labels) 126 | dummy_test_loss += dummy_loss.item() 127 | 128 | if plot_correlation: 129 | test_preds.append(outputs.cpu().numpy()) 130 | test_labels.append(labels.cpu().numpy()) 131 | 132 | if plot_correlation and epoch % 10 == 0: 133 | test_preds = np.concatenate(test_preds, axis=0) 134 | test_labels = np.concatenate(test_labels, axis=0) 135 | plt.figure(figsize=(8, 8)) 136 | plt.scatter(test_labels, test_preds, alpha=0.1) 137 | plt.xlabel("True labels") 138 | plt.ylabel("Predicted labels") 139 | plt.plot([0, 1], [0, 1], color='r', linestyle='--') 140 | plt.title(f"Epoch {epoch}, r² = {r2_score(test_labels, test_preds):.3f}") 141 | plt.xlim(0, 1) 142 | plt.ylim(0, 1) 143 | plt.savefig("test_set_predictions.png") 144 | plt.close() 145 | 146 | test_loss /= len(test_loader) 147 | dummy_test_loss /= len(test_loader) 148 | model.train() 149 | return test_loss, dummy_test_loss 150 | 151 | def plot_losses(losses, y_axis_percentile_cutoff = 99.75, include_y_zero = 1): 152 | plt.figure(figsize=(16, 8)) 153 | plt.plot(losses[0], label="Train") 154 | plt.plot(losses[1], label="Test") 155 | plt.axhline(y=min(losses[1]), color='r', linestyle='--', label="Best test loss") 156 | all_losses = losses[0] + losses[1] 157 | if include_y_zero: 158 | plt.ylim(0, np.percentile(all_losses, y_axis_percentile_cutoff)) 159 | else: 160 | plt.ylim(np.min(all_losses), np.percentile(all_losses, y_axis_percentile_cutoff)) 161 | plt.xlabel("Epoch") 162 | plt.ylabel("MSE loss") 163 | plt.legend() 164 | plt.savefig("losses.png") 165 | plt.close() 166 | 167 | test_loss, dummy_test_loss = get_test_loss(model, test_loader, -1) 168 | print(f"\nBefore training, test mse-loss: {test_loss:.4f} (dummy: {dummy_test_loss:.4f})") 169 | 170 | for epoch in range(args.n_epochs): 171 | model.train() 172 | train_loss = 0.0 173 | for features, labels in train_loader: 174 | optimizer.zero_grad() 175 | outputs = model(features) 176 | loss = criterion(outputs.squeeze(), labels) 177 | loss.backward() 178 | optimizer.step() 179 | train_loss += loss.item() 180 | 181 | train_loss = train_loss / len(train_loader) 182 | test_loss, dummy_test_loss = get_test_loss(model, test_loader, epoch) 183 | losses[0].append(train_loss) 184 | losses[1].append(test_loss) 185 | if epoch % 2 == 0: 186 | test_str = f", test mse: {test_loss:.4f} (dummy: {dummy_test_loss:.4f})" if test_loss > 0 else "" 187 | print(f"Epoch {epoch+1} / {args.n_epochs}, train-mse: {train_loss:.4f}{test_str}") 188 | if epoch % (args.n_epochs // 
10) == 0: 189 | plot_losses(losses) 190 | 191 | # Report: 192 | if test_loss > 0: 193 | print(f"---> Best test mse loss: {min(losses[1]):.4f} in epoch {np.argmin(losses[1])+1}") 194 | plot_losses(losses) 195 | 196 | if not args.dont_save: # Save the model 197 | model.eval() 198 | n_train = len(train_dataset) / 1000 199 | timestamp = pd.Timestamp.now().strftime("%Y-%m-%d_%H:%M:%S") 200 | model_save_name = f"{args.model_name}_{timestamp}_{n_train:.1f}k_imgs_{args.n_epochs}_epochs_{losses[1][-1]:.4f}_mse" 201 | os.makedirs("models", exist_ok=True) 202 | 203 | with open(f"models/{model_save_name}.pkl", "wb") as file: 204 | pickle.dump(model, file) 205 | 206 | print("Final model saved as:\n", f"models/{model_save_name}.pkl") 207 | 208 | if __name__ == "__main__": 209 | parser = argparse.ArgumentParser() 210 | 211 | # IO args: 212 | parser.add_argument('--train_data_dir', type=str, help='Root directory of the (optionally multiple) datasets') 213 | parser.add_argument('--train_data_names', type=str, nargs='+', help='Names of the dataset files to train on (space separated)') 214 | parser.add_argument('--model_name', type=str, default='regressor', help='Name of the model when saved to disk') 215 | parser.add_argument('--dont_save', action='store_true', help='dont save the model to disk') 216 | 217 | # Training args: 218 | parser.add_argument('--test_fraction', type=float, default=0.25, help='Fraction of the training data to use for testing') 219 | parser.add_argument('--n_epochs', type=int, default=80, help='Number of epochs to train for') 220 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training') 221 | parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate') 222 | parser.add_argument('--weight_decay', type=float, default=0.0005, help='Weight decay for the Adam optimizer (default: 0.001)') 223 | parser.add_argument('--dropout_prob', type=float, default=0.5, help='Dropout probability') 224 | parser.add_argument('--hidden_sizes', type=int, nargs='+', default=[128,128,64], help='Hidden sizes of the FC neural network') 225 | 226 | parser.add_argument('--print_network_layout', action='store_true', help='Print the network layout') 227 | parser.add_argument('--random_seed', type=int, default=42, help='Random seed for reproducibility') 228 | args = parser.parse_args() 229 | 230 | train(args) --------------------------------------------------------------------------------
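# A minimal sketch of loading a regressor pickled by utils/train_latent_regressor.py
# (the filename is a placeholder; SimpleFC from utils/nn_model.py must be importable when unpickling):
#   import pickle
#   with open("models/<model_save_name>.pkl", "rb") as f:
#       model = pickle.load(f)
#   model.eval()
#   device = next(model.parameters()).device
#   # prompt_embeds: one of the uuid.pth tensors the script was trained on; the output lies in [0, 1] (sigmoid head)
#   score = model(prompt_embeds.flatten().unsqueeze(0).float().to(device)).item()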