├── .gitignore
├── LICENSE
├── README.md
├── demo_images
│   ├── bad
│   │   ├── 13776_subsurface_scattering_6.png
│   │   ├── 4071_Monsters_Under_The_Bed_2.png
│   │   ├── 5491_Bin_make_of_plastic_on_a_countertop_officesupplies_hdphotography_6.png
│   │   ├── 58781_childrens_drawing_of_a_monster_5.png
│   │   └── 8404_three_long_fingers_on_one_hand_6.png
│   ├── good
│   │   ├── 13623_Abyss_swallowed_the_moon_detailed_digital_space_art_trending_on_artstation_6.png
│   │   ├── 17676_Her_flowers_dying_in_the_window_the_city_screams_like_heavens_gone_haunted_expressionist_painting_trending_on_artstation_6.png
│   │   ├── 19727_the_thundering_joy_of_the_divine_observer_Alphonse_Mucha_looking_from_above_incredible_amazingly_beautiful_oil_painting_smooth_concept_art_4k_matte_painting_by_greg_rutkowski_thomas_kinkade_Ted_Nasmith_key_art_artstation_6.png
│   │   ├── 27327_orange_tabby_cat_wearing_a_green_wig_by_scott_christian_sava_4.png
│   │   └── 7939_a_sloth_sitting_like_a_human_drinking_tea_1940s_style_2.png
│   ├── great
│   │   ├── 15627_The_tryptamine_abyss_oil_painting_trending_on_artstation_7.png
│   │   ├── 21259_a_massive_deactivated_robot_covered_in_moss_in_a_forest_clearing_concept_art_by_anton_fadeev_and_marc_simonetti_and_simon_stalenhag_trending_on_artstation_cgsociety_5.png
│   │   ├── 53822_Elven_owl_portrait_wlop_artgerm_rossdraws__Ross_Tran_Bo_Chen_Rebecca_Oborn_Michael_Whelan_tom_bagshaw_Margarita_Kareva_Susan_Schroder_Sarah_Ann_Loreth_ArtStation_CGsociety_1.png
│   │   ├── 53906_megastructure_appears_over_city_time_frozen._Hypperealist_symmetrical_dramatic_cinematic_composition_environment_scene_poster_illustration_by_John_Harris_6.png
│   │   └── 55486_intaglio_vintage_stamp_portrait_of_Cyborg_Robot_Michael_Jordan_Greg_Rutkowski_James_Gilleard_Ishbel_Myerscough_Scott_Radke_Jean-Baptiste_Monge_jakub_rozalski_by_Belacqua_5.png
│   └── mediocre
│       ├── 10155_kangaroo_fight_with_conor_McGregor_in_MMA_style_8.png
│       ├── 16152_Young_Arab_boy_in_yellow_robes_walking_into_a_desert_of_ashes.___concept_art_marc_simonetti_james_tissot_nekro_phil_hale._hd_3.png
│       ├── 16381_emotional_support_puppet_on_trial_for_fraud_and_elder_abuse_live_courtroom_coverage_on_CSPAN_3.png
│       ├── 18625_Grooving_with_the_eternal_now_detailed_digital_art_by_Jonathan_Solter_pixiv_artstation_5.png
│       └── 8315_a_portrait_of_someone_in_a_car_5.png
├── models
│   ├── sac_public_2022_06_29_vit_b_16_linear.pth
│   ├── sac_public_2022_06_29_vit_b_32_linear.pth
│   └── sac_public_2022_06_29_vit_l_14_linear.pth
├── rank_images.py
├── sacManualSort.png
├── sacModelSort.png
├── simulacra_compute_embeddings.py
└── simulacra_fit_linear_model.py

/.gitignore:
--------------------------------------------------------------------------------
venv*
__pycache__
.ipynb_checkpoints
*.egg-info
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2022 Katherine Crowson and John David Pressman

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Introduction

This is model fitting and inference code for CLIP aesthetic regressions trained on
[Simulacra Aesthetic Captions](https://github.com/JD-P/simulacra-aesthetic-captions).
These remarkably simple models emulate human aesthetic judgment. They can be used
in tasks such as dataset filtering, to remove obviously poor-quality images from
a corpus before training. The following grids, one sorted by
[John David Pressman](https://github.com/JD-P) and one sorted by the machine, give
some idea of the models' capabilities:

### Manually Sorted Grid

![A human-sorted grid of 20 images from worst to best, starting with the worst image in the top left and the best in the bottom right](https://github.com/crowsonkb/simulacra-aesthetic-models/raw/master/sacManualSort.png)

### Model Sorted Grid

![A machine-sorted grid of 20 images from worst to best, starting with the worst image in the top left and the best in the bottom right](https://github.com/crowsonkb/simulacra-aesthetic-models/raw/master/sacModelSort.png)

## Installation

Git clone this repository:

```
git clone https://github.com/crowsonkb/simulacra-aesthetic-models.git
```

Install PyTorch if you don't already have it (this command pins a CUDA 11.3 build;
pick the variant that matches your system):

```
pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
```

Then pip install our other dependencies (scikit-learn is published on PyPI as
`scikit-learn`; the old `sklearn` package name is deprecated):

```
pip3 install tqdm pillow torchvision scikit-learn numpy
```

If you don't already have it installed, you'll need to install CLIP. Run these
commands from the root of this repository: the scripts import CLIP with
`from CLIP import clip`, which expects the cloned `CLIP` directory to sit next to them.

```
git clone https://github.com/openai/CLIP.git
cd CLIP
pip3 install .
cd ..
```

## Usage

The models are largely meant to be used as a library, so you'll need to write
code specific to your use case. To get you started, we've provided a sample
script, `rank_images.py`, which finds all the `.jpg` and `.png` images in a
directory tree, scores them with the ViT-B/16 aesthetic model, and prints the
top N (default 50):

```
python3 rank_images.py demo_images/
```
--------------------------------------------------------------------------------
/demo_images/bad/13776_subsurface_scattering_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/bad/13776_subsurface_scattering_6.png
--------------------------------------------------------------------------------
/demo_images/bad/4071_Monsters_Under_The_Bed_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/bad/4071_Monsters_Under_The_Bed_2.png
--------------------------------------------------------------------------------
/demo_images/bad/5491_Bin_make_of_plastic_on_a_countertop_officesupplies_hdphotography_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/bad/5491_Bin_make_of_plastic_on_a_countertop_officesupplies_hdphotography_6.png
--------------------------------------------------------------------------------
/demo_images/bad/58781_childrens_drawing_of_a_monster_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/bad/58781_childrens_drawing_of_a_monster_5.png
--------------------------------------------------------------------------------
/demo_images/bad/8404_three_long_fingers_on_one_hand_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/bad/8404_three_long_fingers_on_one_hand_6.png
--------------------------------------------------------------------------------
/demo_images/good/13623_Abyss_swallowed_the_moon_detailed_digital_space_art_trending_on_artstation_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/good/13623_Abyss_swallowed_the_moon_detailed_digital_space_art_trending_on_artstation_6.png
--------------------------------------------------------------------------------
/demo_images/good/17676_Her_flowers_dying_in_the_window_the_city_screams_like_heavens_gone_haunted_expressionist_painting_trending_on_artstation_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/good/17676_Her_flowers_dying_in_the_window_the_city_screams_like_heavens_gone_haunted_expressionist_painting_trending_on_artstation_6.png
--------------------------------------------------------------------------------
/demo_images/good/19727_the_thundering_joy_of_the_divine_observer_Alphonse_Mucha_looking_from_above_incredible_amazingly_beautiful_oil_painting_smooth_concept_art_4k_matte_painting_by_greg_rutkowski_thomas_kinkade_Ted_Nasmith_key_art_artstation_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/good/19727_the_thundering_joy_of_the_divine_observer_Alphonse_Mucha_looking_from_above_incredible_amazingly_beautiful_oil_painting_smooth_concept_art_4k_matte_painting_by_greg_rutkowski_thomas_kinkade_Ted_Nasmith_key_art_artstation_6.png
--------------------------------------------------------------------------------
/demo_images/good/27327_orange_tabby_cat_wearing_a_green_wig_by_scott_christian_sava_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/good/27327_orange_tabby_cat_wearing_a_green_wig_by_scott_christian_sava_4.png
--------------------------------------------------------------------------------
/demo_images/good/7939_a_sloth_sitting_like_a_human_drinking_tea_1940s_style_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/good/7939_a_sloth_sitting_like_a_human_drinking_tea_1940s_style_2.png
--------------------------------------------------------------------------------
/demo_images/great/15627_The_tryptamine_abyss_oil_painting_trending_on_artstation_7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/great/15627_The_tryptamine_abyss_oil_painting_trending_on_artstation_7.png
--------------------------------------------------------------------------------
/demo_images/great/21259_a_massive_deactivated_robot_covered_in_moss_in_a_forest_clearing_concept_art_by_anton_fadeev_and_marc_simonetti_and_simon_stalenhag_trending_on_artstation_cgsociety_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/great/21259_a_massive_deactivated_robot_covered_in_moss_in_a_forest_clearing_concept_art_by_anton_fadeev_and_marc_simonetti_and_simon_stalenhag_trending_on_artstation_cgsociety_5.png
--------------------------------------------------------------------------------
/demo_images/great/53822_Elven_owl_portrait_wlop_artgerm_rossdraws__Ross_Tran_Bo_Chen_Rebecca_Oborn_Michael_Whelan_tom_bagshaw_Margarita_Kareva_Susan_Schroder_Sarah_Ann_Loreth_ArtStation_CGsociety_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/great/53822_Elven_owl_portrait_wlop_artgerm_rossdraws__Ross_Tran_Bo_Chen_Rebecca_Oborn_Michael_Whelan_tom_bagshaw_Margarita_Kareva_Susan_Schroder_Sarah_Ann_Loreth_ArtStation_CGsociety_1.png
--------------------------------------------------------------------------------
/demo_images/great/53906_megastructure_appears_over_city_time_frozen._Hypperealist_symmetrical_dramatic_cinematic_composition_environment_scene_poster_illustration_by_John_Harris_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/great/53906_megastructure_appears_over_city_time_frozen._Hypperealist_symmetrical_dramatic_cinematic_composition_environment_scene_poster_illustration_by_John_Harris_6.png
--------------------------------------------------------------------------------
/demo_images/great/55486_intaglio_vintage_stamp_portrait_of_Cyborg_Robot_Michael_Jordan_Greg_Rutkowski_James_Gilleard_Ishbel_Myerscough_Scott_Radke_Jean-Baptiste_Monge_jakub_rozalski_by_Belacqua_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/great/55486_intaglio_vintage_stamp_portrait_of_Cyborg_Robot_Michael_Jordan_Greg_Rutkowski_James_Gilleard_Ishbel_Myerscough_Scott_Radke_Jean-Baptiste_Monge_jakub_rozalski_by_Belacqua_5.png
--------------------------------------------------------------------------------
/demo_images/mediocre/10155_kangaroo_fight_with_conor_McGregor_in_MMA_style_8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/mediocre/10155_kangaroo_fight_with_conor_McGregor_in_MMA_style_8.png
--------------------------------------------------------------------------------
/demo_images/mediocre/16152_Young_Arab_boy_in_yellow_robes_walking_into_a_desert_of_ashes.___concept_art_marc_simonetti_james_tissot_nekro_phil_hale._hd_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/mediocre/16152_Young_Arab_boy_in_yellow_robes_walking_into_a_desert_of_ashes.___concept_art_marc_simonetti_james_tissot_nekro_phil_hale._hd_3.png
--------------------------------------------------------------------------------
/demo_images/mediocre/16381_emotional_support_puppet_on_trial_for_fraud_and_elder_abuse_live_courtroom_coverage_on_CSPAN_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/mediocre/16381_emotional_support_puppet_on_trial_for_fraud_and_elder_abuse_live_courtroom_coverage_on_CSPAN_3.png
--------------------------------------------------------------------------------
/demo_images/mediocre/18625_Grooving_with_the_eternal_now_detailed_digital_art_by_Jonathan_Solter_pixiv_artstation_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/mediocre/18625_Grooving_with_the_eternal_now_detailed_digital_art_by_Jonathan_Solter_pixiv_artstation_5.png
--------------------------------------------------------------------------------
/demo_images/mediocre/8315_a_portrait_of_someone_in_a_car_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/demo_images/mediocre/8315_a_portrait_of_someone_in_a_car_5.png
--------------------------------------------------------------------------------
/models/sac_public_2022_06_29_vit_b_16_linear.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/models/sac_public_2022_06_29_vit_b_16_linear.pth
--------------------------------------------------------------------------------
/models/sac_public_2022_06_29_vit_b_32_linear.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/models/sac_public_2022_06_29_vit_b_32_linear.pth
--------------------------------------------------------------------------------
/models/sac_public_2022_06_29_vit_l_14_linear.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/models/sac_public_2022_06_29_vit_l_14_linear.pth
--------------------------------------------------------------------------------
/rank_images.py:
--------------------------------------------------------------------------------
import os
from argparse import ArgumentParser
from tqdm import tqdm
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
import torch
from simulacra_fit_linear_model import AestheticMeanPredictionLinearModel
from CLIP import clip

parser = ArgumentParser()
parser.add_argument("directory")
# type=int so that values passed with -t are compared as numbers, not strings
parser.add_argument("-t", "--top-n", type=int, default=50)
args = parser.parse_args()

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

clip_model_name = 'ViT-B/16'
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)

normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

# 512 is the embedding dimension of ViT-B/16 CLIP
model = AestheticMeanPredictionLinearModel(512)
model.load_state_dict(
    torch.load("models/sac_public_2022_06_29_vit_b_16_linear.pth", map_location='cpu')
)
model = model.to(device)

def get_filepaths(parentpath, filepaths):
    """Recursively collect every file path below parentpath."""
    paths = []
    for path in filepaths:
        try:
            new_parent = os.path.join(parentpath, path)
            paths += get_filepaths(new_parent, os.listdir(new_parent))
        except NotADirectoryError:
            paths.append(os.path.join(parentpath, path))
    return paths

filepaths = get_filepaths(args.directory, os.listdir(args.directory))
scores = []
for path in tqdm(filepaths):
    # This is obviously a flawed way to check for an image but this is just
    # a demo script anyway.
    if path[-4:].lower() not in (".png", ".jpg"):
        continue
    img = Image.open(path).convert('RGB')
    img = TF.resize(img, 224, transforms.InterpolationMode.LANCZOS)
    img = TF.center_crop(img, (224, 224))
    img = TF.to_tensor(img).to(device)
    img = normalize(img)
    with torch.no_grad():
        clip_image_embed = F.normalize(
            clip_model.encode_image(img[None, ...]).float(),
            dim=-1)
        score = model(clip_image_embed).item()
    # Keep the top_n best (score, path) pairs, sorted ascending by score.
    if len(scores) < args.top_n:
        scores.append((score, path))
        scores.sort(key=lambda x: x[0])
    elif scores and scores[0][0] < score:
        scores.append((score, path))
        scores.sort(key=lambda x: x[0])
        scores = scores[1:]

# Printed in ascending order, so the highest-scoring image comes last.
for score, path in scores:
    print(f"{score}: {path}")
--------------------------------------------------------------------------------
/sacManualSort.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/sacManualSort.png
--------------------------------------------------------------------------------
/sacModelSort.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/crowsonkb/simulacra-aesthetic-models/701f7678b305acf8a0634350365a86beb95dd87d/sacModelSort.png
--------------------------------------------------------------------------------
/simulacra_compute_embeddings.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""Precomputes CLIP embeddings for Simulacra Aesthetic Captions."""

import argparse
from pathlib import Path
import sqlite3

from PIL import Image

import torch
from torch import multiprocessing as mp
from torch.utils import data
from tqdm import tqdm

from CLIP import clip


class SimulacraDataset(data.Dataset):
    """Simulacra Aesthetic Captions images paired with their mean rating.

    Args:
        images_dir: directory containing the dataset images
        db: path to the SQLite ratings database
        transform: preprocessing (and optional augmentation) applied to each image
    """

    def __init__(self, images_dir, db, transform=None):
        self.images_dir = Path(images_dir)
        self.transform = transform
        # Read all rows up front and close the connection: sqlite3 connections
        # cannot be pickled, which DataLoader workers started with 'spawn' need.
        conn = sqlite3.connect(db)
        try:
            self.ratings = conn.execute(
                'SELECT generations.id, images.idx, paths.path, AVG(ratings.rating) '
                'FROM images JOIN generations ON images.gid=generations.id '
                'JOIN ratings ON images.id=ratings.iid '
                'JOIN paths ON images.id=paths.iid GROUP BY images.id'
            ).fetchall()
        finally:
            conn.close()

    def __len__(self):
        return len(self.ratings)

    def __getitem__(self, key):
        gid, idx, filename, rating = self.ratings[key]
        image = Image.open(self.images_dir / filename).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(rating)


def main():
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('--batch-size', '-bs', type=int, default=10,
                   help='the batch size')
    p.add_argument('--clip-model', type=str, default='ViT-B/16',
                   help='the CLIP model')
    p.add_argument('--db', type=str, required=True,
                   help='the database location')
    p.add_argument('--device', type=str,
                   help='the device to use')
    p.add_argument('--images-dir', type=str, required=True,
                   help='the dataset images directory')
    p.add_argument('--num-workers', type=int, default=8,
                   help='the number of data loader workers')
    p.add_argument('--output', type=str, required=True,
                   help='the output file')
    p.add_argument('--start-method', type=str, default='spawn',
                   choices=['fork', 'forkserver', 'spawn'],
                   help='the multiprocessing start method')
    args = p.parse_args()

    mp.set_start_method(args.start_method)
    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    clip_model, clip_tf = clip.load(args.clip_model, device=device, jit=False)
    clip_model = clip_model.eval().requires_grad_(False)

    dataset = SimulacraDataset(args.images_dir, args.db, transform=clip_tf)
    loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers)

    embeds, ratings = [], []

    for images_batch, ratings_batch in tqdm(loader):
        embeds.append(clip_model.encode_image(images_batch.to(device)).cpu())
        ratings.append(ratings_batch.clone())

    obj = {'clip_model': args.clip_model,
           'embeds': torch.cat(embeds),
           'ratings': torch.cat(ratings)}

    torch.save(obj, args.output)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/simulacra_fit_linear_model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

"""Fits a linear aesthetic model to precomputed CLIP embeddings."""

import argparse

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.nn import functional as F


class AestheticMeanPredictionLinearModel(nn.Module):
    def __init__(self, feats_in):
        super().__init__()
        self.linear = nn.Linear(feats_in, 1)

    def forward(self, input):
        # Unit-normalize, then rescale so the embedding has norm sqrt(feats_in)
        x = F.normalize(input, dim=-1) * input.shape[-1] ** 0.5
        return self.linear(x)


def main():
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('input', type=str, help='the input feature vectors')
    p.add_argument('output', type=str, help='the output model')
    p.add_argument('--val-size', type=float, default=0.1, help='the validation set size')
    p.add_argument('--seed', type=int, default=0, help='the random seed')
    args = p.parse_args()

    train_set = torch.load(args.input, map_location='cpu')
    # Same normalization as AestheticMeanPredictionLinearModel.forward
    X = F.normalize(train_set['embeds'].float(), dim=-1).numpy()
    X *= X.shape[-1] ** 0.5
    y = train_set['ratings'].numpy()
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=args.val_size, random_state=args.seed)
    regression = Ridge()
    regression.fit(X_train, y_train)
    score_train = regression.score(X_train, y_train)
    score_val = regression.score(X_val, y_val)
    print(f'Score on train: {score_train:g}')
    print(f'Score on val: {score_val:g}')
    # Copy the fitted ridge coefficients into the PyTorch linear layer
    model = AestheticMeanPredictionLinearModel(X_train.shape[1])
    with torch.no_grad():
        model.linear.weight.copy_(torch.tensor(regression.coef_))
        model.linear.bias.copy_(torch.tensor(regression.intercept_))
    torch.save(model.state_dict(), args.output)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
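The README describes the models as mainly meant for library use, for example dataset filtering. As a rough illustration of what the linear head computes once CLIP has produced an image embedding, here is a dependency-free sketch that mirrors `AestheticMeanPredictionLinearModel.forward` for a single vector and uses it to filter by a threshold. It is not part of the repository: `scale_embedding`, `aesthetic_score`, `filter_by_score` and the toy weights are hypothetical; real weights would come from the `linear.weight` and `linear.bias` entries of one of the `.pth` state dicts.

```python
import math
from typing import List, Sequence


def scale_embedding(embed: Sequence[float]) -> List[float]:
    """Unit-normalize, then multiply by sqrt(d), as the model's forward pass does.

    The result has Euclidean norm sqrt(d), i.e. each component has an RMS of 1.
    The 1e-12 floor mimics the default eps of torch.nn.functional.normalize.
    """
    d = len(embed)
    norm = max(math.sqrt(sum(v * v for v in embed)), 1e-12)
    return [v / norm * math.sqrt(d) for v in embed]


def aesthetic_score(embed: Sequence[float], weight: Sequence[float], bias: float) -> float:
    """Hypothetical helper: dot product with the linear layer's weights plus its bias."""
    x = scale_embedding(embed)
    return sum(w * xi for w, xi in zip(weight, x)) + bias


def filter_by_score(embeds, weight, bias, threshold) -> List[int]:
    """Hypothetical helper: indices of embeddings whose predicted rating is >= threshold."""
    return [i for i, e in enumerate(embeds) if aesthetic_score(e, weight, bias) >= threshold]


# Toy usage with made-up 2-d weights (real ViT-B/16 CLIP embeddings are 512-d).
keep = filter_by_score([[1.0, 0.0], [0.0, 1.0]], weight=[1.0, 0.0], bias=0.0, threshold=0.5)
```

Because the embedding is unit-normalized before rescaling, the score does not depend on the embedding's overall scale, which is why `rank_images.py` normalizing the CLIP embedding before calling the model does not change the result.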
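`SimulacraDataset` pairs each image with the mean of all its ratings through a single SQL `GROUP BY`. To see what that query returns, the sketch below runs it unchanged against an in-memory SQLite database. The four-table schema is a hypothetical minimum inferred from the columns the query joins on, not the real Simulacra Aesthetic Captions schema, and the rows are made up.

```python
import sqlite3

# The query used by SimulacraDataset.__init__.
QUERY = (
    "SELECT generations.id, images.idx, paths.path, AVG(ratings.rating) "
    "FROM images JOIN generations ON images.gid=generations.id "
    "JOIN ratings ON images.id=ratings.iid "
    "JOIN paths ON images.id=paths.iid GROUP BY images.id"
)


def build_toy_db() -> sqlite3.Connection:
    """In-memory database with a hypothetical minimal schema and made-up rows."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE generations (id INTEGER PRIMARY KEY);
        CREATE TABLE images (id INTEGER PRIMARY KEY, gid INTEGER, idx INTEGER);
        CREATE TABLE ratings (iid INTEGER, rating INTEGER);
        CREATE TABLE paths (iid INTEGER, path TEXT);
        INSERT INTO generations VALUES (1);
        -- Three images from one generation; image 3 is never rated.
        INSERT INTO images VALUES (1, 1, 0), (2, 1, 1), (3, 1, 2);
        INSERT INTO ratings VALUES (1, 6), (1, 8), (2, 3);
        INSERT INTO paths VALUES (1, 'a.png'), (2, 'b.png'), (3, 'c.png');
        """
    )
    return conn


def mean_ratings(conn: sqlite3.Connection):
    """Run the dataset query; sort because GROUP BY does not guarantee row order."""
    return sorted(conn.execute(QUERY).fetchall())


rows = mean_ratings(build_toy_db())
```

Image 3 drops out because the inner `JOIN ratings` finds no matching rows for it, and the image rated 6 and 8 gets a regression target of 7.0.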
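`rank_images.py` keeps its best N results by appending to a list and re-sorting it for every image. A bounded min-heap from the standard library's `heapq` module does the same bookkeeping with one comparison against the current worst entry per image. This is an alternative sketch, not what the script does, and `top_n` is a hypothetical helper name.

```python
import heapq
from typing import Iterable, List, Tuple


def top_n(scored: Iterable[Tuple[float, str]], n: int) -> List[Tuple[float, str]]:
    """Return the n highest-scoring (score, path) pairs, best first.

    The heap holds at most n items and its root (heap[0]) is the weakest of
    them, so a new item only enters if it beats that root.
    """
    if n <= 0:
        return []
    heap: List[Tuple[float, str]] = []
    for item in scored:
        if len(heap) < n:
            heapq.heappush(heap, item)
        elif item[0] > heap[0][0]:
            # Pop the current weakest entry and push the new one in one step.
            heapq.heapreplace(heap, item)
    return sorted(heap, reverse=True)


data = [(0.1, "a.png"), (0.9, "b.png"), (0.5, "c.png"), (0.7, "d.png")]
best = top_n(data, 2)
```

As in the script, a new score that only ties the current minimum does not displace it. For a list that is already in memory, `heapq.nlargest(n, scored)` gives the same result in one call; the explicit loop mirrors the script's one-image-at-a-time structure.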