├── README.md
├── assets
│   ├── pixart.png
│   ├── sd1.4.png
│   ├── sd1.5.png
│   ├── sd2.1.png
│   ├── teaser.png
│   └── transfer
│       ├── ltx-video.gif
│       ├── videocrafter2.gif
│       └── zeroscope.gif
├── checkpoints
│   ├── pixart-alpha_reneg_emb.bin
│   ├── sd1.4_reneg_emb.bin
│   ├── sd1.5_reneg_emb.bin
│   └── sd2.1_reneg_emb.bin
├── inference.py
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
# ReNeg: Learning Negative Embedding with Reward Guidance

[arXiv:2412.19637](https://arxiv.org/abs/2412.19637)

![teaser](assets/teaser.png)
We present **ReNeg**, a **Re**ward-guided approach that directly learns **Neg**ative embeddings through gradient descent. The global negative embeddings learned with **ReNeg** generalize strongly and can be seamlessly adapted to text-to-image and even text-to-video models. Strikingly simple yet highly effective, **ReNeg** amplifies the visual appeal of outputs from base Stable Diffusion models.

If you find `ReNeg`'s open-source effort useful, please 🌟 star us to encourage our continued development!

## 🔧 Installation
```bash
conda create -n reneg python=3.8.5
conda activate reneg
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu118

# Clone the ReNeg repository
git clone https://github.com/XiaominLi1997/ReNeg.git
cd ReNeg
pip install -r requirements.txt
```
## 🗄️ Models and Demos
Any text-conditioned generative model that uses the same text encoder can share its negative embedding. We provide ReNeg embeddings for the following common text encoders; each checkpoint is simply a tensor in the output space of its text encoder, as the sketch below illustrates.
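A minimal sketch of what a checkpoint contains (the `(1, seq_len, hidden_dim)` shape is an assumption; inspect your own copy):

```python
import torch

# Each ReNeg checkpoint is a learned negative prompt embedding in the output
# space of its text encoder; any model sharing that encoder can consume it.
neg_emb = torch.load("checkpoints/sd1.5_reneg_emb.bin", map_location="cpu")
print(type(neg_emb), tuple(neg_emb.shape))  # expected roughly: (1, seq_len, hidden_dim)
```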
### 1. Models

| Text Encoder | Base Model | ReNeg Embedding |
| --- | --- | --- |
| CLIP ViT-L/14 | SD1.4 | `checkpoints/sd1.4_reneg_emb.bin` |
| CLIP ViT-L/14 | SD1.5 | `checkpoints/sd1.5_reneg_emb.bin` |
| OpenCLIP-ViT/H | SD2.1 | `checkpoints/sd2.1_reneg_emb.bin` |
| T5-v1.1-xxl | Pixart-alpha | `checkpoints/pixart-alpha_reneg_emb.bin` |
### 2. Demos

#### Negative Embeddings of SD and Pixart-alpha:

| Text Encoder | Model | Results |
| --- | --- | --- |
| CLIP ViT-L/14 | SD1.4 | ![sd1.4](assets/sd1.4.png) |
| CLIP ViT-L/14 | SD1.5 | ![sd1.5](assets/sd1.5.png) |
| OpenCLIP-ViT/H | SD2.1 | ![sd2.1](assets/sd2.1.png) |
| T5-v1.1-xxl | Pixart-alpha | ![pixart](assets/pixart.png) |
#### Transfer of Negative Embeddings:

| Text Encoder | Transfer of Neg. Emb. | Results |
| --- | --- | --- |
| OpenCLIP-ViT/H | SD2.1 -> ZeroScope | ![zeroscope](assets/transfer/zeroscope.gif) |
| OpenCLIP-ViT/H | SD2.1 -> VideoCrafter2 | ![videocrafter2](assets/transfer/videocrafter2.gif) |
| T5-v1.1-xxl | Pixart-alpha -> LTX-Video | ![ltx-video](assets/transfer/ltx-video.gif) |
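Because the embedding lives in the text encoder's output space, transfer amounts to passing it to the target pipeline. A minimal sketch for the SD2.1 -> ZeroScope case, assuming `diffusers`' `TextToVideoSDPipeline` and the `cerspense/zeroscope_v2_576w` checkpoint (neither is bundled with this repo):

```python
import torch
from diffusers import TextToVideoSDPipeline
from diffusers.utils import export_to_video

# ZeroScope shares SD2.1's OpenCLIP-ViT/H text encoder, so the SD2.1 ReNeg
# embedding can be passed in directly as the negative prompt embedding.
pipe = TextToVideoSDPipeline.from_pretrained(
    "cerspense/zeroscope_v2_576w", torch_dtype=torch.float16
).to("cuda")
neg_emb = torch.load("checkpoints/sd2.1_reneg_emb.bin").to("cuda", torch.float16)

# Note: the structure of `.frames` differs across diffusers versions;
# this follows the list-of-frames convention of diffusers 0.20.x.
video_frames = pipe(
    "A girl in a school uniform playing an electric guitar.",
    negative_prompt_embeds=neg_emb,
    num_inference_steps=30,
).frames
export_to_video(video_frames, "zeroscope_reneg.mp4")
```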
## 💻 Inference
First specify the paths to your `SD1.5` checkpoint and the negative embedding (`neg_emb`). By default, we place `neg_emb` under the `checkpoints/` directory.
```bash
python inference.py --model_path "your_sd1.5_path" --neg_embeddings_path "checkpoints/sd1.5_reneg_emb.bin" --prompt "A girl in a school uniform playing an electric guitar."
```

To compare against the results obtained with `neg_emb`, you can run inference with only the positive prompt, or with a hand-written negative prompt; for programmatic use, see the sketch after this list.
+ To run **inference with only the positive prompt**, pass `--prompt_type "only_pos"`:
```bash
python inference.py --model_path "your_sd1.5_path" --prompt_type "only_pos" --prompt "A girl in a school uniform playing an electric guitar."
```
+ To run **inference with a positive and a negative prompt**, pass `--prompt_type "neg_prompt"`. The default negative prompt is: `distorted, ugly, blurry, low resolution, low quality, bad, deformed, disgusting, Overexposed, Simple background, Plain background, Grainy, Underexposed, too dark, too bright, too low contrast, too high contrast, Broken, Macabre, artifacts, oversaturated`
```bash
python inference.py --model_path "your_sd1.5_path" --prompt_type "neg_prompt" --prompt "A girl in a school uniform playing an electric guitar."
```
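
You can also use the embedding directly from Python instead of through `inference.py`. A minimal sketch, mirroring what `inference.py` does (`"your_sd1.5_path"` is a placeholder for your local checkpoint):

```python
import torch
from diffusers import StableDiffusionPipeline

# Load SD1.5 and swap the learned ReNeg embedding in as the negative prompt.
pipe = StableDiffusionPipeline.from_pretrained(
    "your_sd1.5_path", safety_checker=None
).to("cuda")
neg_emb = torch.load("checkpoints/sd1.5_reneg_emb.bin").to("cuda")

image = pipe(
    "A girl in a school uniform playing an electric guitar.",
    negative_prompt_embeds=neg_emb,
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]
image.save("reneg_sample.jpg")
```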

## 📋 Todo List
- [x] Inference code
- [ ] Training code
- [ ] Online Demo

## ❤️ Acknowledgements
This project is based on [ImageReward](https://github.com/THUDM/ImageReward) and [diffusers](https://github.com/huggingface/diffusers). Thanks for their awesome work.


## Citation

```bibtex
@misc{li2024reneg,
      title={ReNeg: Learning Negative Embedding with Reward Guidance},
      author={Xiaomin Li and Yixuan Liu and Takashi Isobe and Xu Jia and Qinpeng Cui and Dong Zhou and Dong Li and You He and Huchuan Lu and Zhongdao Wang and Emad Barsoum},
      year={2024},
      eprint={2412.19637},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

--------------------------------------------------------------------------------
/assets/pixart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/pixart.png
--------------------------------------------------------------------------------
/assets/sd1.4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/sd1.4.png
--------------------------------------------------------------------------------
/assets/sd1.5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/sd1.5.png
--------------------------------------------------------------------------------
/assets/sd2.1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/sd2.1.png
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/teaser.png
--------------------------------------------------------------------------------
/assets/transfer/ltx-video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/transfer/ltx-video.gif
--------------------------------------------------------------------------------
/assets/transfer/videocrafter2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/transfer/videocrafter2.gif
--------------------------------------------------------------------------------
/assets/transfer/zeroscope.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/assets/transfer/zeroscope.gif
--------------------------------------------------------------------------------
/checkpoints/pixart-alpha_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/pixart-alpha_reneg_emb.bin
--------------------------------------------------------------------------------
/checkpoints/sd1.4_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/sd1.4_reneg_emb.bin
--------------------------------------------------------------------------------
/checkpoints/sd1.5_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/sd1.5_reneg_emb.bin
--------------------------------------------------------------------------------
/checkpoints/sd2.1_reneg_emb.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LemonTwoL/ReNeg/8b4192b4b1b06440e61e344e3fa48b218bd49f64/checkpoints/sd2.1_reneg_emb.bin
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import os
import argparse
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="",
        help="Path to a local SD1.5 checkpoint or a Hugging Face model id.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A girl in a school uniform playing an electric guitar.",
    )
    parser.add_argument(
        "--prompt_type",
        type=str,
        default="neg_emb",
        choices=["neg_emb", "neg_prompt", "only_pos"],
        help="neg_emb: learned ReNeg embedding; neg_prompt: hand-written negative prompt; only_pos: positive prompt only.",
    )
    parser.add_argument(
        "--neg_prompt",
        type=str,
        default="distorted, ugly, blurry, low resolution, low quality, bad, deformed, disgusting, Overexposed, Simple background, Plain background, Grainy, Underexposed, too dark, too bright, too low contrast, too high contrast, Broken, Macabre, artifacts, oversaturated",
    )
    parser.add_argument(
        "--neg_embeddings_path",
        type=str,
        default="checkpoints/sd1.5_reneg_emb.bin",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="outputs",
    )
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Load the Stable Diffusion pipeline (safety checker disabled for evaluation)
    # and switch to the DDIM scheduler shipped with the checkpoint.
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path,
        safety_checker=None,
    )
    pipe.scheduler = DDIMScheduler.from_pretrained(
        args.model_path, subfolder="scheduler"
    )
    device = "cuda"
    pipe.to(device)
    generator = torch.Generator().manual_seed(args.seed)

    os.makedirs(args.output_path, exist_ok=True)
    if args.prompt_type == "neg_emb":
        # Use the learned ReNeg embedding in place of a negative prompt.
        neg_embeddings = torch.load(args.neg_embeddings_path).to(device)
        output = pipe(
            args.prompt,
            negative_prompt_embeds=neg_embeddings,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=7.5,
            generator=generator,
        )
    elif args.prompt_type == "neg_prompt":
        # Baseline: classifier-free guidance with a hand-written negative prompt.
        output = pipe(
            args.prompt,
            negative_prompt=args.neg_prompt,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=7.5,
            generator=generator,
        )
    elif args.prompt_type == "only_pos":
        # Baseline: positive prompt only.
        output = pipe(
            args.prompt,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=7.5,
            generator=generator,
        )

    image = output.images[0]
    file_name = args.prompt.replace(" ", "_")
    output_file = Path(args.output_path) / f"{args.prompt_type}_{file_name}.jpg"
    image.save(output_file)
    print(f"Saved image to {output_file}")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==2.0.0+cu118
torchvision==0.15.1+cu118
datasets==2.14.6
diffusers==0.20.0
tqdm==4.66.5
transformers==4.25.1
huggingface-hub==0.24.5
fairscale==0.4.13
timm==1.0.9
accelerate==0.20.0
clip==0.2.0
--------------------------------------------------------------------------------