├── README.md
├── assets
    ├── pixart.png
    ├── sd1.4.png
    ├── sd1.5.png
    ├── sd2.1.png
    ├── teaser.png
    └── transfer
    │   ├── ltx-video.gif
    │   ├── videocrafter2.gif
    │   └── zeroscope.gif
├── checkpoints
    ├── pixart-alpha_reneg_emb.bin
    ├── sd1.4_reneg_emb.bin
    ├── sd1.5_reneg_emb.bin
    └── sd2.1_reneg_emb.bin
├── inference.py
└── requirements.txt


/README.md:
--------------------------------------------------------------------------------
# ReNeg: Learning Negative Embedding with Reward Guidance

[![arXiv](https://img.shields.io/badge/arXiv%20paper-2412.19637-b31b1b.svg)](https://arxiv.org/abs/2412.19637)

We present **ReNeg**, a **Re**ward-guided approach that directly learns **Neg**ative embeddings through gradient descent. The global negative embeddings learned with **ReNeg** generalize well and can be plugged directly into text-to-image and even text-to-video models. Simple yet highly effective, **ReNeg** improves the visual appeal of outputs from base Stable Diffusion models.

If you find `ReNeg`'s open-source effort useful, please 🌟 the repository to encourage further development!

## 🔧 Installation
```bash
conda create -n reneg python=3.8.5
conda activate reneg
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118

# Clone the ReNeg repository
git clone https://github.com/XiaominLi1997/ReNeg.git
cd ReNeg
pip install -r requirements.txt
```

## 🗄️ Models and Demos
Text-conditioned generative models that use the same text encoder can share negative embeddings. We provide ReNeg embeddings for the following common text encoders.

### 1. Models
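| Text Encoder | Model | Path |
| --- | --- | --- |
| CLIP ViT-L/14 | SD1.4 | [sd1.4_reneg_emb](checkpoints/sd1.4_reneg_emb.bin) |
| CLIP ViT-L/14 | SD1.5 | [sd1.5_reneg_emb](checkpoints/sd1.5_reneg_emb.bin) |
| OpenCLIP-ViT/H | SD2.1 | [sd2.1_reneg_emb](checkpoints/sd2.1_reneg_emb.bin) |
| T5-v1.1-xxl | Pixart-alpha | [pixart-alpha_reneg_emb](checkpoints/pixart-alpha_reneg_emb.bin) |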
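
A negative embedding stands in for the encoded negative prompt, so its shape has to match the prompt embeddings the pipeline produces. diffusers' `StableDiffusionPipeline` concatenates the two before running the UNet for classifier-free guidance. The sketch below is a hypothetical pre-flight check: `expected_neg_emb_shape` and `check_neg_emb_shape` are not part of this repository. It covers only the CLIP-based SD checkpoints that `inference.py` loads. The widths 768 (CLIP ViT-L/14) and 1024 (OpenCLIP-ViT/H) and the 77-token length are the standard SD1.x/SD2.1 values. The `(batch, tokens, width)` layout is an assumption about how the `.bin` files are stored.

```python
from typing import Tuple

# Assumed text-encoder widths (hidden size of the prompt embeddings).
TEXT_ENCODER_WIDTH = {
    "CLIP ViT-L/14": 768,    # SD1.4, SD1.5
    "OpenCLIP-ViT/H": 1024,  # SD2.1
}
CLIP_MAX_TOKENS = 77  # CLIP tokenizers pad/truncate prompts to 77 tokens


def expected_neg_emb_shape(text_encoder: str, batch_size: int = 1) -> Tuple[int, int, int]:
    """Shape a negative embedding needs so it can be paired with the prompt embeddings."""
    return (batch_size, CLIP_MAX_TOKENS, TEXT_ENCODER_WIDTH[text_encoder])


def check_neg_emb_shape(shape: Tuple[int, ...], text_encoder: str, batch_size: int = 1) -> None:
    """Raise ValueError if `shape` does not match the expected prompt-embedding shape."""
    expected = expected_neg_emb_shape(text_encoder, batch_size)
    if tuple(shape) != expected:
        raise ValueError(f"negative embedding shape {tuple(shape)} != expected {expected}")
```

To run the check on a real checkpoint, pass it the loaded tensor's shape before calling the pipeline. For example: `check_neg_emb_shape(tuple(torch.load(path).shape), "CLIP ViT-L/14")`.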

### 2. Demos

#### Negative Embeddings of SD and Pixart-alpha
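| Text Encoder | Model | Results |
| --- | --- | --- |
| CLIP ViT-L/14 | SD1.4 | ![SD1.4 results](assets/sd1.4.png) |
| CLIP ViT-L/14 | SD1.5 | ![SD1.5 results](assets/sd1.5.png) |
| OpenCLIP-ViT/H | SD2.1 | ![SD2.1 results](assets/sd2.1.png) |
| T5-v1.1-xxl | Pixart-alpha | ![Pixart-alpha results](assets/pixart.png) |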

#### Transfer of Negative Embeddings
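| Text Encoder | Transfer of Neg. Emb. | Results |
| --- | --- | --- |
| OpenCLIP-ViT/H | SD2.1 → ZeroScope | ![ZeroScope results](assets/transfer/zeroscope.gif) |
| OpenCLIP-ViT/H | SD2.1 → VideoCrafter2 | ![VideoCrafter2 results](assets/transfer/videocrafter2.gif) |
| T5-v1.1-xxl | Pixart-alpha → LTX-Video | ![LTX-Video results](assets/transfer/ltx-video.gif) |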
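
The Models and Transfer tables reduce to one rule: an embedding can be reused wherever the text encoder matches. A small lookup makes the rule explicit. This is an illustrative sketch, not repository code. The model-to-encoder mapping is copied from the two tables. A matching encoder is the precondition the README names, not a guarantee of equal quality on every model; only the transfers listed above are demonstrated.

```python
from typing import List

# Model -> text encoder, taken from the Models and Transfer tables.
MODEL_TO_ENCODER = {
    "SD1.4": "CLIP ViT-L/14",
    "SD1.5": "CLIP ViT-L/14",
    "SD2.1": "OpenCLIP-ViT/H",
    "ZeroScope": "OpenCLIP-ViT/H",
    "VideoCrafter2": "OpenCLIP-ViT/H",
    "Pixart-alpha": "T5-v1.1-xxl",
    "LTX-Video": "T5-v1.1-xxl",
}

# Released embeddings in checkpoints/, keyed by the model they were learned on.
RELEASED = {
    "SD1.4": "checkpoints/sd1.4_reneg_emb.bin",
    "SD1.5": "checkpoints/sd1.5_reneg_emb.bin",
    "SD2.1": "checkpoints/sd2.1_reneg_emb.bin",
    "Pixart-alpha": "checkpoints/pixart-alpha_reneg_emb.bin",
}


def compatible_checkpoints(target_model: str) -> List[str]:
    """Released embeddings whose source model shares the target's text encoder."""
    encoder = MODEL_TO_ENCODER[target_model]
    return sorted(
        path for source, path in RELEASED.items()
        if MODEL_TO_ENCODER[source] == encoder
    )
```

For example, `compatible_checkpoints("ZeroScope")` points to the SD2.1 embedding, matching the first transfer row. `typing.List` is used instead of `list[str]` because the installation instructions pin Python 3.8.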

## 💻 Inference
First, specify the path to your `SD1.5` model and to the negative embedding (`neg_emb`). By default, the embeddings are stored in the `checkpoints/` directory. Generated images are saved to `outputs/` (configurable with `--output_path`).
```bash
python inference.py --model_path "your_sd1.5_path" --neg_embeddings_path "checkpoints/sd1.5_reneg_emb.bin" --prompt "A girl in a school uniform playing an electric guitar."
```

To compare against the results obtained with `neg_emb`, you can run inference with the positive prompt only, or with a handcrafted negative prompt.
+ To run **inference with only the positive prompt**, pass `--prompt_type only_pos`:
```bash
python inference.py --model_path "your_sd1.5_path" --prompt_type "only_pos" --prompt "A girl in a school uniform playing an electric guitar."
```
+ To run **inference with a positive prompt plus a negative prompt**, pass `--prompt_type neg_prompt`. The default negative prompt (override it with `--neg_prompt`) is: `distorted, ugly, blurry, low resolution, low quality, bad, deformed, disgusting, Overexposed, Simple background, Plain background, Grainy, Underexposed, too dark, too bright, too low contrast, too high contrast, Broken, Macabre, artifacts, oversaturated`
```bash
python inference.py --model_path "your_sd1.5_path" --prompt_type "neg_prompt" --prompt "A girl in a school uniform playing an electric guitar."
```

## 📋 Todo List
- [x] Inference code
- [ ] Training code
- [ ] Online Demo

## ❤️ Acknowledgements
This project is based on [ImageReward](https://github.com/THUDM/ImageReward) and [diffusers](https://github.com/huggingface/diffusers). Thanks for their excellent work.
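
The three inference modes above differ only in `--prompt_type`. `inference.py` names each output `<prompt_type>_<prompt with spaces replaced by underscores>.jpg` under `--output_path`, so a side-by-side comparison is easy to script. `build_comparison_runs` below is a hypothetical helper, not part of the repository. It only builds the argument lists and the predicted output paths, and `"your_sd1.5_path"` is a placeholder, as in the commands above.

```python
from pathlib import Path
from typing import Dict, List, Tuple

PROMPT_TYPES = ("neg_emb", "neg_prompt", "only_pos")  # choices accepted by inference.py


def build_comparison_runs(
    model_path: str,
    prompt: str,
    neg_embeddings_path: str = "checkpoints/sd1.5_reneg_emb.bin",
    output_path: str = "outputs",
) -> Dict[str, Tuple[List[str], Path]]:
    """Return {prompt_type: (argv, expected_output_file)} for each comparison mode."""
    file_name = prompt.replace(" ", "_")  # same naming rule as inference.py
    runs = {}
    for prompt_type in PROMPT_TYPES:
        argv = [
            "python", "inference.py",
            "--model_path", model_path,
            "--prompt_type", prompt_type,
            "--prompt", prompt,
            "--output_path", output_path,
        ]
        if prompt_type == "neg_emb":
            # Only the neg_emb mode reads the learned embedding file.
            argv += ["--neg_embeddings_path", neg_embeddings_path]
        runs[prompt_type] = (argv, Path(output_path) / f"{prompt_type}_{file_name}.jpg")
    return runs
```

Run each entry from the repository root with `subprocess.run(argv, check=True)`. Doing so produces the three images at the returned paths, ready to be viewed side by side.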

## Citation

```
@misc{li2024reneg,
  title={ReNeg: Learning Negative Embedding with Reward Guidance},
  author={Xiaomin Li and Yixuan Liu and Takashi Isobe and Xu Jia and Qinpeng Cui and Dong Zhou and Dong Li and You He and Huchuan Lu and Zhongdao Wang and Emad Barsoum},
  year={2024},
  eprint={2412.19637},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```


--------------------------------------------------------------------------------
/assets/pixart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/pixart.png
--------------------------------------------------------------------------------
/assets/sd1.4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/sd1.4.png
--------------------------------------------------------------------------------
/assets/sd1.5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/sd1.5.png
--------------------------------------------------------------------------------
/assets/sd2.1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/sd2.1.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/teaser.png
--------------------------------------------------------------------------------
/assets/transfer/ltx-video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/transfer/ltx-video.gif
--------------------------------------------------------------------------------
/assets/transfer/videocrafter2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/transfer/videocrafter2.gif
--------------------------------------------------------------------------------
/assets/transfer/zeroscope.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/transfer/zeroscope.gif
--------------------------------------------------------------------------------
/checkpoints/pixart-alpha_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/pixart-alpha_reneg_emb.bin
--------------------------------------------------------------------------------
/checkpoints/sd1.4_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/sd1.4_reneg_emb.bin
--------------------------------------------------------------------------------
/checkpoints/sd1.5_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/sd1.5_reneg_emb.bin
--------------------------------------------------------------------------------
/checkpoints/sd2.1_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/sd2.1_reneg_emb.bin
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import os
from pathlib import Path
import torch
from diffusers import (
    StableDiffusionPipeline,
    DDIMScheduler,
)
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Local path or Hub ID of the Stable Diffusion model (e.g. SD1.5).",
    )
    # parser.add_argument("--prompt_path", type=str, default="prompts/prompts.txt")
    parser.add_argument(
        "--prompt",
        type=str,
        default="A girl in a school uniform playing an electric guitar.",
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        default="neg_emb",
        choices=["neg_emb", "neg_prompt", "only_pos"],
    )
    parser.add_argument(
        "--neg_prompt",
        type=str,
        default="distorted, ugly, blurry, low resolution, low quality, bad, deformed, disgusting, Overexposed, Simple background, Plain background, Grainy, Underexposed, too dark, too bright, too low contrast, too high contrast, Broken, Macabre, artifacts, oversaturated",
    )
    parser.add_argument(
        "--neg_embeddings_path",
        type=str,
        default="checkpoints/sd1.5_reneg_emb.bin",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="outputs",
    )
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path,
        safety_checker=None,
    )
    # Swap in the DDIM scheduler stored alongside the model weights.
    pipe.scheduler = DDIMScheduler.from_pretrained(
        args.model_path, subfolder="scheduler"
    )
    device = "cuda"
    pipe.to(device)
    # Seeded generator so repeated runs produce the same image.
    generator = torch.Generator().manual_seed(args.seed)

    os.makedirs(args.output_path, exist_ok=True)
    if args.prompt_type == "neg_emb":
        # Learned ReNeg embedding, used in place of an encoded negative prompt.
        neg_embeddings = torch.load(args.neg_embeddings_path, map_location="cpu").to(device)
        output = pipe(
            args.prompt,
            negative_prompt_embeds=neg_embeddings,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=7.5,
            generator=generator,
        )
    elif args.prompt_type == "neg_prompt":
        # Handcrafted negative prompt, encoded by the pipeline's text encoder.
        output = pipe(
            args.prompt,
            negative_prompt=args.neg_prompt,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=7.5,
            generator=generator,
        )
    elif args.prompt_type == "only_pos":
        # No negative conditioning: the pipeline uses the empty prompt instead.
        output = pipe(
            args.prompt,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=7.5,
            generator=generator,
        )
    image = output.images[0]
    # File name: <prompt_type>_<prompt with spaces replaced by underscores>.jpg
    file_name = args.prompt.replace(" ", "_")
    output_file = Path(args.output_path) / f"{args.prompt_type}_{file_name}.jpg"
    image.save(output_file)
    print(f"Saved image to {output_file}")

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==2.0.0+cu118
torchvision==0.15.1+cu118
datasets==2.14.6
diffusers==0.20.0
tqdm==4.66.5
transformers==4.25.1
huggingface-hub==0.24.5
fairscale==0.4.13
timm==1.0.9
accelerate==0.20.0
clip==0.2.0
--------------------------------------------------------------------------------
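
In `inference.py`, all three modes call the pipeline with `guidance_scale=7.5`. What changes between them is the unconditional branch of classifier-free guidance. In diffusers' Stable Diffusion pipeline, the UNet predicts noise for both the negative (unconditional) embeddings and the text embeddings, and combines the two as `eps = eps_uncond + s * (eps_text - eps_uncond)`. The unconditional branch differs by mode:

- `only_pos` encodes the empty prompt.
- `neg_prompt` encodes `--neg_prompt`.
- `neg_emb` uses the learned ReNeg tensor directly.

The sketch below applies the combination step elementwise. It uses plain Python lists in place of latent tensors, purely for illustration.

```python
from typing import List


def classifier_free_guidance(
    noise_uncond: List[float],
    noise_text: List[float],
    guidance_scale: float = 7.5,
) -> List[float]:
    """Combine unconditional and text-conditioned noise predictions elementwise."""
    if len(noise_uncond) != len(noise_text):
        raise ValueError("noise predictions must have the same length")
    # Start from the unconditional (negative) prediction and move toward the text
    # prediction, scaled by guidance_scale; a scale of 1 returns noise_text unchanged.
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]
```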