├── .gitignore ├── README.md ├── assets ├── simple_prompts.txt ├── unknown_image.jpg └── workflow_sdxl_turbo.json ├── environment.yml ├── helper.py └── minimal_usage_latent_clip.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | cache/ 2 | __pycache__/ 3 | latent_clip*/ 4 | Latent_ReNO*/ 5 | outputs/ 6 | models/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Supplementary Directory 3 | 4 | This repository contains resources and code for utilizing **Latent CLIP** for zero-shot prediction and reward-based noise optimization. 5 | 6 | --- 7 | 8 | ## 📋 Installation 9 | 10 | To set up the environment, use the provided `environment.yml` file: 11 | 12 | ```bash 13 | conda env create -f environment.yml 14 | conda activate latentclipenv 15 | ``` 16 | 17 | --- 18 | 19 | ## 🚀 Usage 20 | 21 | The starting point for understanding and utilizing **Latent CLIP** is the Jupyter Notebook: 22 | 23 | **`minimal_usage_latent_clip.ipynb`** 24 | 25 | This notebook demonstrates: 26 | - **Zero-shot prediction** using Latent CLIP. 27 | - **Reward-based noise optimization** using Latent CLIP-based rewards . 28 | 29 | --- 30 | 31 | ### 📌 Using the ComfyUI Workflow 32 | The file **`workflow_sdxl_turbo.json`** is a workflow designed for use with **ComfyUI**. 33 | 34 | **Steps to use the workflow:** 35 | 1. **Clone ComfyUI** from GitHub: 36 | ```bash 37 | git clone https://github.com/comfyanonymous/ComfyUI.git 38 | cd ComfyUI 39 | ``` 40 | 41 | 2. **Run ComfyUI**: 42 | ```bash 43 | python main.py 44 | ``` 45 | 46 | 3. **Load the Workflow**: 47 | - In the ComfyUI interface, press **`Ctrl + O`** (or click "Load"). 48 | - Select the file **`workflow_sdxl_turbo.json`** from the `assets/` folder. 49 | 50 | 4. 
For more information on ComfyUI, visit: 51 | ➡️ [ComfyUI GitHub Repository](https://github.com/comfyanonymous/ComfyUI) 52 | 53 | --- 54 | 55 | ## 📂 Directory Structure 56 | 57 | ``` 58 | supplementary/ 59 | │ 60 | ├── assets/ # Additional resources 61 | │ └── workflow_sdxl_turbo.json # ComfyUI workflow file (https://github.com/comfyanonymous/ComfyUI) 62 | │ 63 | ├── Latent_ReNO/ # Implementation for reward-based noise optimization 64 | ├── environment.yml # Conda environment setup file 65 | ├── helper.py # Utility functions for supporting the notebook 66 | └── minimal_usage_latent_clip.ipynb # Main notebook for starting with Latent CLIP 67 | ``` 68 | -------------------------------------------------------------------------------- /assets/simple_prompts.txt: -------------------------------------------------------------------------------- 1 | a car 2 | an apple 3 | an eagle 4 | a tree 5 | a house 6 | a book 7 | a cup 8 | a table 9 | a chair 10 | a lamp 11 | a mountain 12 | a river 13 | a beach 14 | a boat 15 | a fish 16 | a bird 17 | a dog 18 | a cat 19 | a horse 20 | a cow 21 | a bus 22 | a train 23 | a plane 24 | a truck 25 | a bicycle 26 | a road 27 | a bridge 28 | a door 29 | a window 30 | a clock 31 | a phone 32 | a laptop 33 | a pen 34 | a pencil 35 | a notebook 36 | a backpack 37 | a hat 38 | a shoe 39 | a sock 40 | a jacket 41 | a treehouse 42 | a fireplace 43 | a television 44 | a fridge 45 | a washing machine 46 | a mirror 47 | a spoon 48 | a fork 49 | a knife 50 | a plate 51 | a ball 52 | a doll 53 | a toy car 54 | a teddy bear 55 | a robot 56 | a kite 57 | a drum 58 | a guitar 59 | a violin 60 | a piano 61 | a road sign 62 | a mailbox 63 | a stoplight 64 | a park bench 65 | a street lamp 66 | a tent 67 | a campfire 68 | a sleeping bag 69 | a suitcase 70 | a globe 71 | a firetruck 72 | a police car 73 | an ambulance 74 | a tractor 75 | a windmill 76 | a sunflower 77 | a rose 78 | a lily 79 | a tree branch 80 | a pinecone 81 | a teapot 82 | a frying pan 83 
| a cutting board 84 | a whisk 85 | a rolling pin 86 | a hamburger 87 | a pizza 88 | a hotdog 89 | a sandwich 90 | a donut 91 | a glass of water 92 | a milk carton 93 | a juice box 94 | a cereal bowl 95 | a cake 96 | a hammer 97 | a screwdriver 98 | a wrench 99 | a saw 100 | a drill -------------------------------------------------------------------------------- /assets/unknown_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsonBackup/Latent-CLIP-Demo/7b9a4eb0281f58790b9713c804e42e677f7769ea/assets/unknown_image.jpg -------------------------------------------------------------------------------- /assets/workflow_sdxl_turbo.json: -------------------------------------------------------------------------------- 1 | { 2 | "3": { 3 | "inputs": { 4 | "seed": 342279256743858, 5 | "steps": 2, 6 | "cfg": 1.1, 7 | "sampler_name": "euler_ancestral", 8 | "scheduler": "normal", 9 | "denoise": 0.66, 10 | "model": [ 11 | "4", 12 | 0 13 | ], 14 | "positive": [ 15 | "6", 16 | 0 17 | ], 18 | "negative": [ 19 | "7", 20 | 0 21 | ], 22 | "latent_image": [ 23 | "10", 24 | 0 25 | ] 26 | }, 27 | "class_type": "KSampler", 28 | "_meta": { 29 | "title": "KSampler" 30 | } 31 | }, 32 | "4": { 33 | "inputs": { 34 | "ckpt_name": "sd_xl_turbo_1.0_fp16.safetensors" 35 | }, 36 | "class_type": "CheckpointLoaderSimple", 37 | "_meta": { 38 | "title": "Load Checkpoint" 39 | } 40 | }, 41 | "6": { 42 | "inputs": { 43 | "text": "a man with long hair and a pirate beard", 44 | "clip": [ 45 | "4", 46 | 1 47 | ] 48 | }, 49 | "class_type": "CLIPTextEncode", 50 | "_meta": { 51 | "title": "CLIP Text Encode (Prompt)" 52 | } 53 | }, 54 | "7": { 55 | "inputs": { 56 | "text": "text, watermark", 57 | "clip": [ 58 | "4", 59 | 1 60 | ] 61 | }, 62 | "class_type": "CLIPTextEncode", 63 | "_meta": { 64 | "title": "CLIP Text Encode (Prompt)" 65 | } 66 | }, 67 | "8": { 68 | "inputs": { 69 | "samples": [ 70 | "3", 71 | 0 72 | ], 73 | "vae": [ 74 | 
"4", 75 | 2 76 | ] 77 | }, 78 | "class_type": "VAEDecode", 79 | "_meta": { 80 | "title": "VAE Decode" 81 | } 82 | }, 83 | "9": { 84 | "inputs": { 85 | "filename_prefix": "ComfyUI", 86 | "images": [ 87 | "8", 88 | 0 89 | ] 90 | }, 91 | "class_type": "SaveImage", 92 | "_meta": { 93 | "title": "Save Image" 94 | } 95 | }, 96 | "10": { 97 | "inputs": { 98 | "pixels": [ 99 | "12", 100 | 0 101 | ], 102 | "vae": [ 103 | "4", 104 | 2 105 | ] 106 | }, 107 | "class_type": "VAEEncode", 108 | "_meta": { 109 | "title": "VAE Encode" 110 | } 111 | }, 112 | "11": { 113 | "inputs": { 114 | "image": "me (1).png", 115 | "upload": "image" 116 | }, 117 | "class_type": "LoadImage", 118 | "_meta": { 119 | "title": "Load Image" 120 | } 121 | }, 122 | "12": { 123 | "inputs": { 124 | "upscale_method": "nearest-exact", 125 | "megapixels": 0.25, 126 | "image": [ 127 | "14", 128 | 0 129 | ] 130 | }, 131 | "class_type": "ImageScaleToTotalPixels", 132 | "_meta": { 133 | "title": "ImageScaleToTotalPixels" 134 | } 135 | }, 136 | "13": { 137 | "inputs": { 138 | "filename_prefix": "latents/ComfyUI", 139 | "samples": [ 140 | "3", 141 | 0 142 | ] 143 | }, 144 | "class_type": "SaveLatent", 145 | "_meta": { 146 | "title": "SaveLatent" 147 | } 148 | }, 149 | "14": { 150 | "inputs": { 151 | "upscale_method": "nearest-exact", 152 | "width": 512, 153 | "height": 512, 154 | "crop": "center", 155 | "image": [ 156 | "11", 157 | 0 158 | ] 159 | }, 160 | "class_type": "ImageScale", 161 | "_meta": { 162 | "title": "Upscale Image" 163 | } 164 | } 165 | } -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: latentclipenv 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_kmp_llvm 9 | - aom=3.9.1=hac33072_0 10 | - appdirs=1.4.4=pyh9f0ad1d_0 11 | - asttokens=2.4.1=pyhd8ed1ab_0 12 | - blas=1.0=mkl 
13 | - brotli-python=1.1.0=py311hb755f60_1 14 | - bzip2=1.0.8=hd590300_5 15 | - ca-certificates=2025.1.31=hbcca054_0 16 | - cairo=1.18.0=hbb29018_2 17 | - certifi=2025.1.31=pyhd8ed1ab_0 18 | - cffi=1.16.0=py311hb3a22ac_0 19 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 20 | - click=8.1.7=unix_pyh707e725_0 21 | - colorama=0.4.6=pyhd8ed1ab_0 22 | - comm=0.2.2=pyhd8ed1ab_0 23 | - cuda-cudart=12.1.105=0 24 | - cuda-cupti=12.1.105=0 25 | - cuda-libraries=12.1.0=0 26 | - cuda-nvrtc=12.1.105=0 27 | - cuda-nvtx=12.1.105=0 28 | - cuda-opencl=12.5.39=0 29 | - cuda-runtime=12.1.0=0 30 | - cuda-version=12.5=3 31 | - dav1d=1.2.1=hd590300_0 32 | - debugpy=1.8.8=py311hfdbb021_0 33 | - decorator=5.1.1=pyhd8ed1ab_0 34 | - docker-pycreds=0.4.0=py_0 35 | - exceptiongroup=1.2.2=pyhd8ed1ab_0 36 | - executing=2.1.0=pyhd8ed1ab_0 37 | - expat=2.6.2=h59595ed_0 38 | - ffmpeg=7.0.1=gpl_h9be9148_104 39 | - filelock=3.15.4=pyhd8ed1ab_0 40 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 41 | - font-ttf-inconsolata=3.000=h77eed37_0 42 | - font-ttf-source-code-pro=2.038=h77eed37_0 43 | - font-ttf-ubuntu=0.83=h77eed37_2 44 | - fontconfig=2.14.2=h14ed4e7_0 45 | - fonts-conda-ecosystem=1=0 46 | - fonts-conda-forge=1=0 47 | - freetype=2.12.1=h267a509_2 48 | - fribidi=1.0.10=h36c2ea0_0 49 | - gettext=0.22.5=h59595ed_2 50 | - gettext-tools=0.22.5=h59595ed_2 51 | - gitdb=4.0.11=pyhd8ed1ab_0 52 | - gitpython=3.1.43=pyhd8ed1ab_0 53 | - gmp=6.3.0=hac33072_2 54 | - gmpy2=2.1.5=py311hc4f1f91_1 55 | - gnutls=3.7.9=hb077bed_0 56 | - graphite2=1.3.13=h59595ed_1003 57 | - h2=4.1.0=pyhd8ed1ab_0 58 | - harfbuzz=9.0.0=hfac3d4d_0 59 | - hpack=4.0.0=pyh9f0ad1d_0 60 | - hyperframe=6.0.1=pyhd8ed1ab_0 61 | - icu=73.2=h59595ed_0 62 | - idna=3.7=pyhd8ed1ab_0 63 | - imageio=2.34.2=pyh12aca89_0 64 | - ipykernel=6.29.5=pyh3099207_0 65 | - ipython=8.29.0=pyh707e725_0 66 | - ipywidgets=8.1.5=pyhd8ed1ab_1 67 | - jedi=0.19.2=pyhff2d567_0 68 | - jinja2=3.1.4=pyhd8ed1ab_0 69 | - joblib=1.4.2=pyhd8ed1ab_0 70 | - 
jupyter_client=8.6.3=pyhd8ed1ab_0 71 | - jupyter_core=5.7.2=pyh31011fe_1 72 | - jupyterlab_widgets=3.0.13=pyhd8ed1ab_1 73 | - keyutils=1.6.1=h166bdaf_0 74 | - krb5=1.21.3=h659f571_0 75 | - lame=3.100=h166bdaf_1003 76 | - lcms2=2.16=hb7c19ff_0 77 | - ld_impl_linux-64=2.40=hf3520f5_7 78 | - lerc=4.0.0=h27087fc_0 79 | - libabseil=20240116.2=cxx17_h59595ed_0 80 | - libasprintf=0.22.5=h661eb56_2 81 | - libasprintf-devel=0.22.5=h661eb56_2 82 | - libass=0.17.1=h39113c1_2 83 | - libblas=3.9.0=16_linux64_mkl 84 | - libcblas=3.9.0=16_linux64_mkl 85 | - libcublas=12.1.0.26=0 86 | - libcufft=11.0.2.4=0 87 | - libcufile=1.10.0.4=0 88 | - libcurand=10.3.6.39=0 89 | - libcusolver=11.4.4.55=0 90 | - libcusparse=12.0.2.55=0 91 | - libdeflate=1.20=hd590300_0 92 | - libdrm=2.4.122=h4ab18f5_0 93 | - libedit=3.1.20191231=he28a2e2_2 94 | - libexpat=2.6.2=h59595ed_0 95 | - libffi=3.4.2=h7f98852_5 96 | - libgcc=14.2.0=h77fa898_1 97 | - libgcc-ng=14.2.0=h69a702a_1 98 | - libgettextpo=0.22.5=h59595ed_2 99 | - libgettextpo-devel=0.22.5=h59595ed_2 100 | - libgfortran-ng=14.1.0=h69a702a_0 101 | - libgfortran5=14.1.0=hc5f4f2c_0 102 | - libglib=2.80.2=h8a4344b_1 103 | - libhwloc=2.11.0=default_h5622ce7_1000 104 | - libiconv=1.17=hd590300_2 105 | - libidn2=2.3.7=hd590300_0 106 | - libjpeg-turbo=3.0.0=hd590300_1 107 | - liblapack=3.9.0=16_linux64_mkl 108 | - libnpp=12.0.2.50=0 109 | - libnsl=2.0.1=hd590300_0 110 | - libnvjitlink=12.1.105=0 111 | - libnvjpeg=12.1.1.14=0 112 | - libopenvino=2024.2.0=h2da1b83_1 113 | - libopenvino-auto-batch-plugin=2024.2.0=hb045406_1 114 | - libopenvino-auto-plugin=2024.2.0=hb045406_1 115 | - libopenvino-hetero-plugin=2024.2.0=h5c03a75_1 116 | - libopenvino-intel-cpu-plugin=2024.2.0=h2da1b83_1 117 | - libopenvino-intel-gpu-plugin=2024.2.0=h2da1b83_1 118 | - libopenvino-intel-npu-plugin=2024.2.0=he02047a_1 119 | - libopenvino-ir-frontend=2024.2.0=h5c03a75_1 120 | - libopenvino-onnx-frontend=2024.2.0=h07e8aee_1 121 | - libopenvino-paddle-frontend=2024.2.0=h07e8aee_1 
122 | - libopenvino-pytorch-frontend=2024.2.0=he02047a_1 123 | - libopenvino-tensorflow-frontend=2024.2.0=h39126c6_1 124 | - libopenvino-tensorflow-lite-frontend=2024.2.0=he02047a_1 125 | - libopus=1.3.1=h7f98852_1 126 | - libpciaccess=0.18=hd590300_0 127 | - libpng=1.6.43=h2797004_0 128 | - libprotobuf=4.25.3=h08a7969_0 129 | - libsodium=1.0.20=h4ab18f5_0 130 | - libsqlite=3.46.0=hde9e2c9_0 131 | - libstdcxx=14.2.0=hc0a3c3a_1 132 | - libstdcxx-ng=14.2.0=h4852527_1 133 | - libtasn1=4.19.0=h166bdaf_0 134 | - libtiff=4.6.0=h1dd3fc0_3 135 | - libunistring=0.9.10=h7f98852_0 136 | - libuuid=2.38.1=h0b41bf4_0 137 | - libva=2.22.0=hb711507_0 138 | - libvpx=1.14.1=hac33072_0 139 | - libwebp-base=1.4.0=hd590300_0 140 | - libxcb=1.16=hd590300_0 141 | - libxcrypt=4.4.36=hd590300_1 142 | - libxml2=2.12.7=hc051c1a_1 143 | - libzlib=1.3.1=h4ab18f5_1 144 | - lightning-utilities=0.11.3.post0=pyhd8ed1ab_0 145 | - llvm-openmp=15.0.7=h0cdce71_0 146 | - markupsafe=2.1.5=py311h459d7ec_0 147 | - matplotlib-inline=0.1.7=pyhd8ed1ab_0 148 | - mkl=2022.2.1=h84fe81f_16997 149 | - mpc=1.3.1=hfe3b2da_0 150 | - mpfr=4.2.1=h9458935_1 151 | - mpmath=1.3.0=pyhd8ed1ab_0 152 | - ncurses=6.5=h59595ed_0 153 | - nest-asyncio=1.6.0=pyhd8ed1ab_0 154 | - nettle=3.9.1=h7ab15ed_0 155 | - networkx=3.3=pyhd8ed1ab_1 156 | - ocl-icd=2.3.2=hd590300_1 157 | - openh264=2.4.1=h59595ed_0 158 | - openjpeg=2.5.2=h488ebb8_0 159 | - openssl=3.4.1=h7b32b05_0 160 | - p11-kit=0.24.1=hc5aa10d_0 161 | - packaging=24.1=pyhd8ed1ab_0 162 | - parso=0.8.4=pyhd8ed1ab_0 163 | - pathtools=0.1.2=py_1 164 | - pcre2=10.44=h0f59acf_0 165 | - pexpect=4.9.0=pyhd8ed1ab_0 166 | - pickleshare=0.7.5=py_1003 167 | - pillow=10.4.0=py311h82a398c_0 168 | - pip=24.0=pyhd8ed1ab_0 169 | - pixman=0.43.2=h59595ed_0 170 | - platformdirs=4.3.6=pyhd8ed1ab_0 171 | - prompt-toolkit=3.0.48=pyha770c72_0 172 | - psutil=6.0.0=py311h331c9d8_0 173 | - pthread-stubs=0.4=h36c2ea0_1001 174 | - ptyprocess=0.7.0=pyhd3deb0d_0 175 | - pugixml=1.14=h59595ed_0 176 | - 
pure_eval=0.2.3=pyhd8ed1ab_0 177 | - pycparser=2.22=pyhd8ed1ab_0 178 | - pygments=2.18.0=pyhd8ed1ab_0 179 | - pysocks=1.7.1=pyha2e5f31_6 180 | - python=3.11.9=hb806964_0_cpython 181 | - python_abi=3.11=4_cp311 182 | - pytorch=2.3.0=py3.11_cuda12.1_cudnn8.9.2_0 183 | - pytorch-cuda=12.1=ha16c6d3_5 184 | - pytorch-lightning=2.2.5=pyhd8ed1ab_0 185 | - pytorch-mutex=1.0=cuda 186 | - pyyaml=6.0.1=py311h459d7ec_1 187 | - pyzmq=26.2.0=py311h7deb3e3_3 188 | - readline=8.2=h8228510_1 189 | - requests=2.32.3=pyhd8ed1ab_0 190 | - scikit-learn=1.5.1=py311hd632256_0 191 | - scipy=1.14.0=py311h517d4fd_1 192 | - sentry-sdk=2.9.0=pyhd8ed1ab_0 193 | - setproctitle=1.3.3=py311h459d7ec_0 194 | - six=1.16.0=pyh6c4a22f_0 195 | - smmap=5.0.0=pyhd8ed1ab_0 196 | - snappy=1.2.1=ha2e4443_0 197 | - stack_data=0.6.2=pyhd8ed1ab_0 198 | - svt-av1=2.1.2=hac33072_0 199 | - sympy=1.12.1=pypyh2585a3b_103 200 | - tbb=2021.12.0=h434a139_2 201 | - threadpoolctl=3.5.0=pyhc1e730c_0 202 | - tk=8.6.13=noxft_h4845f30_101 203 | - torchmetrics=1.4.0.post0=pyhd8ed1ab_0 204 | - torchtriton=2.3.0=py311 205 | - torchvision=0.18.0=py311_cu121 206 | - tornado=6.4.1=py311h9ecbd09_1 207 | - tqdm=4.66.4=pyhd8ed1ab_0 208 | - traitlets=5.14.3=pyhd8ed1ab_0 209 | - typing-extensions=4.12.2=hd8ed1ab_0 210 | - typing_extensions=4.12.2=pyha770c72_0 211 | - urllib3=2.2.2=pyhd8ed1ab_1 212 | - wandb=0.16.6=pyhd8ed1ab_0 213 | - wayland=1.23.0=h5291e77_0 214 | - wayland-protocols=1.36=hd8ed1ab_0 215 | - wcwidth=0.2.13=pyhd8ed1ab_0 216 | - wheel=0.43.0=pyhd8ed1ab_1 217 | - widgetsnbextension=4.0.13=pyhd8ed1ab_1 218 | - x264=1!164.3095=h166bdaf_2 219 | - x265=3.5=h924138e_3 220 | - xorg-fixesproto=5.0=h7f98852_1002 221 | - xorg-kbproto=1.0.7=h7f98852_1002 222 | - xorg-libice=1.1.1=hd590300_0 223 | - xorg-libsm=1.2.4=h7391055_0 224 | - xorg-libx11=1.8.9=hb711507_1 225 | - xorg-libxau=1.0.11=hd590300_0 226 | - xorg-libxdmcp=1.1.3=h7f98852_0 227 | - xorg-libxext=1.3.4=h0b41bf4_2 228 | - xorg-libxfixes=5.0.3=h7f98852_1004 229 | - 
xorg-libxrender=0.9.11=hd590300_0 230 | - xorg-renderproto=0.11.1=h7f98852_1002 231 | - xorg-xextproto=7.3.0=h0b41bf4_1003 232 | - xorg-xproto=7.0.31=h7f98852_1007 233 | - xz=5.2.6=h166bdaf_0 234 | - yaml=0.2.5=h7f98852_2 235 | - zeromq=4.3.5=h3b0a872_7 236 | - zlib=1.3.1=h4ab18f5_1 237 | - zstandard=0.22.0=py311hb6f056b_1 238 | - zstd=1.5.6=ha6fb4c9_0 239 | - pip: 240 | - git+https://github.com/wendlerc/latent_clip.git@renamed 241 | - accelerate==0.32.1 242 | - aiohttp==3.9.5 243 | - aiosignal==1.3.1 244 | - args==0.1.0 245 | - attrs==23.2.0 246 | - blobfile==2.1.1 247 | - braceexpand==0.1.7 248 | - clint==0.5.1 249 | - coloredlogs==15.0.1 250 | - contourpy==1.2.1 251 | - cycler==0.12.1 252 | - datasets==2.18.0 253 | - diffusers==0.28.0 254 | - dill==0.3.8 255 | - einops==0.8.0 256 | - fairscale==0.4.13 257 | - fonttools==4.53.1 258 | - frozenlist==1.4.1 259 | - fsspec==2024.2.0 260 | - ftfy==6.2.0 261 | - hpsv2==1.2.0 262 | - huggingface-hub==0.23.4 263 | - humanfriendly==10.0 264 | - image-reward==1.5 265 | - importlib-metadata==8.0.0 266 | - iniconfig==2.0.0 267 | - kiwisolver==1.4.5 268 | - lxml==4.9.4 269 | - matplotlib==3.9.1 270 | - multidict==6.0.5 271 | - multiprocess==0.70.16 272 | - numpy==1.26.4 273 | - open-clip-torch==2.24.0 274 | - openai-clip==1.0.1 275 | - optimum==1.21.1 276 | - pandas==2.2.2 277 | - pluggy==1.5.0 278 | - protobuf==3.20.3 279 | - pyarrow==16.1.0 280 | - pyarrow-hotfix==0.6 281 | - pycryptodomex==3.20.0 282 | - pyparsing==3.1.2 283 | - pytest==7.2.0 284 | - pytest-split==0.8.0 285 | - python-dateutil==2.9.0.post0 286 | - pytz==2024.1 287 | - regex==2024.5.15 288 | - safetensors==0.4.3 289 | - sentencepiece==0.2.0 290 | - setuptools==60.2.0 291 | - timm==0.6.13 292 | - tokenizers==0.15.2 293 | - transformers==4.38.2 294 | - tzdata==2024.1 295 | - webdataset==0.2.86 296 | - xformers==0.0.26.post1 297 | - xxhash==3.4.1 298 | - yarl==1.9.4 299 | - zipp==3.19.2 300 | prefix: ./.conda/envs/latentclipenv 301 | 
"""Utility module backing the Latent CLIP notebook.

Wraps the Latent-ReNO implementation (model loading, reward construction,
latent-noise optimization) and provides small plotting / download helpers.
"""

import os
from dataclasses import dataclass, field
from typing import Optional, List
import blobfile as bf
import logging
import torch
import matplotlib.pyplot as plt
from PIL import Image
from huggingface_hub import hf_hub_download
import shutil

import sys

# Make the Latent-ReNO implementation importable (arguments/models/rewards/training).
sys.path.insert(0, "./Latent-ReNO/")

from arguments import parse_args
from models import get_model
from rewards import get_reward_losses, get_latent_reward_losses
from training import LatentNoiseTrainer, get_optimizer

# Shared device/dtype for all model weights and latent tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16


@dataclass
class Args:
    """Configuration for reward-based latent-noise optimization.

    Mirrors the CLI arguments of Latent-ReNO so the notebook can build a
    config object directly instead of parsing argv.
    """

    # FIX: these two previously had no type annotation, which made them plain
    # class attributes rather than dataclass fields (excluded from __init__,
    # __repr__ and __eq__). Annotated so they behave like every other option.
    softpromptdecay: float = 0.03
    disable_trainable_latents: bool = False
    cache_dir: str = "./cache"
    save_dir: str = "./outputs/"
    model: str = "sdxl-turbo"
    lr: float = 5.0
    n_iters: int = 50
    n_inference_steps: int = 1
    optim: str = "sgd"
    nesterov: bool = True
    grad_clip: float = 0.1
    seed: int = 0
    enable_hps: bool = True
    hps_weighting: float = 10.0
    enable_imagereward: bool = True
    imagereward_weighting: float = 1.0
    enable_only_latents: bool = True
    enable_trainable_prompt: bool = False
    enable_clip_clf: bool = True
    maximize: bool = False
    enable_clip_text: bool = False
    enable_clip_image: bool = False
    enable_clip: bool = False
    clip_model: Optional[str] = None
    clip_weighting: float = 0.01
    latent_guidance_prompt: Optional[str] = None
    enable_pickscore: bool = True
    # FIX: this field was declared twice with the same value; duplicate removed.
    pickscore_weighting: float = -0.1
    enable_aesthetic: bool = False
    aesthetic_weighting: float = -0.1
    enable_md_aesthetic: bool = False
    md_aesthetic_weighting: float = -0.1
    enable_sh_aesthetic: bool = False
    sh_aesthetic_weighting: float = 0.1
    enable_pgen: bool = False
    pgen_weighting: float = 1.0
    enable_nsfw: bool = False
    nsfw_weighting: float = 1.0
    enable_reg: bool = True
    reg_weight: float = 0.01
    task: str = "single"
    prompt: str = "A green elephant and a red mouse"
    negative_prompt: Optional[str] = None
    benchmark_reward: str = "total"
    save_all_images: bool = True
    save_gif: bool = True
    no_optim: bool = False
    imageselect: bool = False
    memsave: bool = False
    device: str = "cuda"
    device_id: Optional[int] = None


def get_sd_model(args):
    """Load the diffusion model named by ``args.model`` onto the shared device."""
    return get_model(args, args.model, dtype, device, args.cache_dir, args.memsave)


def get_latent_noise_trainer(args, sd_model):
    """Build a LatentNoiseTrainer plus a settings string used for output paths.

    Sets up file + console logging on the root logger (only once), optionally
    pins the CUDA device, and selects latent vs. pixel-space reward losses
    depending on ``args.clip_model``.

    Returns:
        (trainer, settings): the configured trainer and the settings string.
    """
    bf.makedirs(f"{args.save_dir}/logs/{args.task}")
    settings = (
        f"{args.model}_max-{args.maximize}_{args.latent_guidance_prompt}_{args.optim}"
    )

    logger = logging.getLogger()
    if not logger.hasHandlers():
        # FIX: use logging.FileHandler (which owns and eventually closes the
        # file) instead of a manually opened stream that was never closed.
        handler = logging.FileHandler(
            f"{args.save_dir}/logs/{args.task}/{settings}.txt", mode="w"
        )
        formatter = logging.Formatter("%(asctime)s - %(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel("INFO")
        consoleHandler = logging.StreamHandler()
        consoleHandler.setFormatter(formatter)
        logger.addHandler(consoleHandler)

    if args.device_id is not None:
        logging.info(f"Using CUDA device {args.device_id}")
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        # FIX: the variable is CUDA_VISIBLE_DEVICES (plural) and os.environ
        # values must be strings; the original wrote an int to the wrong key,
        # so the device pinning silently had no effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)

    # FIX: clip_model defaults to None; guard before calling .lower() to avoid
    # an AttributeError with the default Args configuration.
    if args.clip_model is not None and "latent" in args.clip_model.lower():
        reward_losses = get_latent_reward_losses(args, dtype, device, args.cache_dir)
    else:
        reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)

    trainer = LatentNoiseTrainer(
        reward_losses=reward_losses,
        model=sd_model,
        n_iters=args.n_iters,
        n_inference_steps=args.n_inference_steps,
        seed=args.seed,
        save_all_images=args.save_all_images,
        save_gif=args.save_gif,
        device=device,
        no_optim=args.no_optim,
        regularize=args.enable_reg,
        regularization_weight=args.reg_weight,
        grad_clip=args.grad_clip,
        log_metrics=args.task == "single" or not args.no_optim,
        imageselect=args.imageselect,
        optim=args.optim,
    )

    return trainer, settings


def generate_and_optimize(args, trainer, sd_model, settings):
    """Sample an initial latent, optimize it with the trainer, save results.

    Returns:
        The directory the generated images were written to.
    """
    # Latent shape is derived from the UNet's sample size and the VAE scale.
    height = sd_model.unet.config.sample_size * sd_model.vae_scale_factor
    width = sd_model.unet.config.sample_size * sd_model.vae_scale_factor
    shape = (
        1,
        sd_model.unet.in_channels,
        height // sd_model.vae_scale_factor,
        width // sd_model.vae_scale_factor,
    )

    # FIX: removed two dead torch.randn allocations (init_prompt /
    # init_add_text) that were never used but consumed GPU memory.
    init_latents = torch.randn(shape, device=device, dtype=dtype)

    latents = torch.nn.Parameter(init_latents, requires_grad=True)

    optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
    # Prompt is truncated to 100 chars to keep the directory name filesystem-safe.
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:100]}"
    os.makedirs(f"{save_dir}", exist_ok=True)
    best_image, total_best_rewards, initial_image_pil, total_init_rewards = (
        trainer.train(
            latents,
            args.prompt,
            optimizer,
            save_dir,
            negative_prompt=args.negative_prompt,
        )
    )
    best_image.save(f"{save_dir}/best_image.png")

    return save_dir


def plot_images(save_dir, num_inference_steps=50, only_best=False):
    """Display optimization results from ``save_dir`` with matplotlib.

    Args:
        save_dir: directory containing the images written by the trainer.
        num_inference_steps: how many per-step images (``00.png`` ...) to look for.
        only_best: if True, show only ``best_image.png`` and return.
    """
    if only_best:
        best_image_path = os.path.join(save_dir, "best_image.png")
        if os.path.exists(best_image_path):
            image = Image.open(best_image_path)

            plt.figure(figsize=(2, 2))
            plt.imshow(image)
            plt.axis("off")
            plt.title("Best Image")
            plt.show()
        else:
            print("'best_image.png' not found in the directory.")
        return

    # Handling multiple image plotting when only_best is False.
    valid_filenames = {f"{i:02d}.png" for i in range(num_inference_steps)}
    image_files = sorted([f for f in os.listdir(save_dir) if f in valid_filenames])

    num_images = len(image_files)
    if num_images == 0:
        print("No valid images found in the directory.")
        return

    cols = min(10, num_images)  # Max 10 columns for better readability.
    rows = (num_images + cols - 1) // cols

    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2))
    # subplots returns a bare Axes (not an array) when there is one image.
    axes = axes.flatten() if num_images > 1 else [axes]

    for ax, img_name in zip(axes, image_files):
        img_path = os.path.join(save_dir, img_name)
        image = Image.open(img_path)

        ax.imshow(image)
        ax.axis("off")
        ax.set_title(img_name.split(".")[0])

    # Hide any unused grid cells.
    for ax in axes[num_images:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()


def download_and_rename_model(repo_id, save_as, target_dir, filename="epoch_34.pt"):
    """Download ``filename`` from a HF Hub repo and save it as ``target_dir/save_as``.

    Skips the download when the target file already exists.
    """
    target_path = os.path.join(target_dir, save_as)

    if os.path.exists(target_path):
        print(f"✅ {save_as} already exists. Skipping download.")
        return

    os.makedirs(target_dir, exist_ok=True)

    downloaded_file = hf_hub_download(repo_id=repo_id, filename=filename)

    shutil.copy(downloaded_file, target_path)

    # FIX: the original message contained a redacted "(unknown)" placeholder;
    # report the actual saved filename instead.
    print(f"✅ {save_as} downloaded from {repo_id} and saved as {target_path}")