├── requirements-server.txt ├── assets ├── rewardanything-logo-horizontal.png └── rewardanything-logo-horizontal-dark-mode.png ├── requirements-dev.txt ├── requirements.txt ├── MANIFEST.in ├── examples ├── server_config.json ├── remote_usage.py ├── transformers_usage.py └── local_usage.py ├── rewardanything ├── __init__.py ├── models.py ├── cli.py ├── local.py ├── client.py ├── utils.py ├── processing.py └── serve.py ├── pages ├── _config.yml └── index.html ├── .gitignore ├── pyproject.toml ├── setup.py ├── LICENSE ├── docs └── PROJECT_DOCS.md └── README.md /requirements-server.txt: -------------------------------------------------------------------------------- 1 | fastapi>=0.68.0 2 | uvicorn[standard]>=0.15.0 3 | httpx>=0.24.0 4 | openai>=1.0.0 -------------------------------------------------------------------------------- /assets/rewardanything-logo-horizontal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisdomShell/RewardAnything/HEAD/assets/rewardanything-logo-horizontal.png -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest>=6.0.0 2 | pytest-asyncio>=0.18.0 3 | black>=21.0.0 4 | isort>=5.9.0 5 | flake8>=3.9.0 6 | mypy>=0.910 7 | pre-commit>=2.15.0 -------------------------------------------------------------------------------- /assets/rewardanything-logo-horizontal-dark-mode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisdomShell/RewardAnything/HEAD/assets/rewardanything-logo-horizontal-dark-mode.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | transformers>=4.51.0 3 | tokenizers>=0.13.0 4 | requests>=2.25.0 5 | pydantic>=1.8.0 6 | tqdm>=4.62.0 7 | numpy>=1.21.0 8 | accelerate>=1.7.0 -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include pyproject.toml 4 | recursive-include rewardanything *.py 5 | recursive-include examples *.py *.json 6 | recursive-exclude * __pycache__ 7 | recursive-exclude * *.py[co] -------------------------------------------------------------------------------- /examples/server_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "api_key": ["dummy-key-for-local-vllm"], 3 | "api_model": "zhuohaoyu/RewardAnything-8B-v1", 4 | "api_base": ["http://localhost:8000/v1"], 5 | "api_proxy": null, 6 | "api_timeout": 120.0, 7 | "api_max_retries": 3, 8 | "generation_config": { 9 | "temperature": 0.0, 10 | "max_tokens": 4096, 11 | "top_p": 1.0, 12 | "frequency_penalty": 0.0, 13 | "presence_penalty": 0.0 14 | }, 15 | "num_workers": 4, 16 | "request_limit": 500, 17 | "request_limit_period": 60, 18 | "max_error_count": 30, 19 | "dump_individual_rsp": false 20 | } -------------------------------------------------------------------------------- /rewardanything/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | RewardAnything: Generalizable Principle-Following Reward Models 3 | 4 | This package provides both local and remote inference capabilities for 5 | 
RewardAnything models that can follow natural language evaluation principles. 6 | """ 7 | 8 | from .local import from_pretrained 9 | from .client import Client 10 | from .models import RewardResult, RewardRequest, RewardResponse 11 | # from .benchmarks import RABench 12 | 13 | __version__ = "1.0.1" 14 | __all__ = ["from_pretrained", "Client", "RewardResult", "RewardRequest", "RewardResponse"] 15 | 16 | # Optional benchmarks import (only if available) 17 | try: 18 | from .benchmarks import RABench 19 | __all__.append("RABench") 20 | except ImportError: 21 | pass 22 | -------------------------------------------------------------------------------- /examples/remote_usage.py: -------------------------------------------------------------------------------- 1 | import rewardanything 2 | 3 | # Connect to the RewardAnything server 4 | client = rewardanything.Client("http://localhost:8001") 5 | 6 | # Process batch requests efficiently 7 | requests = [ 8 | { 9 | "principle": "Prefer helpful and safe responses", 10 | "prompt": "How to learn programming?", 11 | "responses": { 12 | "assistant_a": "Start with Python, practice daily, build projects.", 13 | "assistant_b": "Read books and hope for the best." 14 | } 15 | }, 16 | # ... more requests 17 | ] 18 | 19 | results = client.judge_batch(requests) 20 | for result in results: 21 | print(f"Scores: {result.scores}") 22 | print(f"Best to worst: {result.ranking}") 23 | print(f"Reasoning: {result.reasoning}") -------------------------------------------------------------------------------- /rewardanything/models.py: -------------------------------------------------------------------------------- 1 | """Data models and result classes for RewardAnything.""" 2 | 3 | from typing import Dict, List, Optional, Any 4 | from dataclasses import dataclass 5 | from pydantic import BaseModel 6 | 7 | 8 | @dataclass 9 | class RewardResult: 10 | """Result from RewardAnything evaluation.""" 11 | reasoning: str 12 | scores: Dict[str, float] # model_name -> score (1-5) 13 | ranking: List[str] # ordered list from best to worst 14 | raw_output: Optional[str] = None 15 | 16 | def __str__(self) -> str: 17 | return f"RewardResult(scores={self.scores}, ranking={self.ranking})" 18 | 19 | def __repr__(self) -> str: 20 | return self.__str__() 21 | 22 | 23 | class RewardRequest(BaseModel): 24 | """Request format for RewardAnything evaluation.""" 25 | principle: str 26 | prompt: str 27 | responses: Dict[str, str] 28 | mask_responses: bool = True 29 | 30 | 31 | class RewardResponse(BaseModel): 32 | """Response format from RewardAnything server.""" 33 | thoughts: str 34 | results: Dict[str, Any] -------------------------------------------------------------------------------- /pages/_config.yml: -------------------------------------------------------------------------------- 1 | title: "RewardAnything" 2 | description: "Generalizable Principle-Following Reward Models" 3 | url: "https://zhuohaoyu.github.io" 4 | baseurl: "/RewardAnything" 5 | 6 | # Build settings 7 | markdown: kramdown 8 | highlighter: rouge 9 | plugins: 10 | - jekyll-feed 11 | - jekyll-sitemap 12 | - jekyll-seo-tag 13 | 14 | # Collections 15 | collections: 16 | authors: 17 | output: false 18 | 19 | # Exclude files 20 | exclude: 21 | - node_modules/ 22 | - package.json 23 | - package-lock.json 24 | - tailwind.config.js 25 | - postcss.config.js 26 | - Gemfile 27 | - Gemfile.lock 28 | - vendor/ 29 | - .bundle/ 30 | - README.md 31 | 32 | # Social links 33 | github_username: zhuohaoyu 34 | paper_url: 
"https://arxiv.org/abs/2506.03637" 35 | huggingface_url: "https://huggingface.co/WisdomShell/RewardAnything-8B-v1" 36 | pypi_url: "https://pypi.org/project/rewardanything/" 37 | 38 | # Project info 39 | version: "1.0.1" 40 | license: "Apache-2.0" 41 | 42 | # GitHub Pages specific settings 43 | github: [metadata] 44 | kramdown: 45 | input: GFM 46 | syntax_highlighter: rouge -------------------------------------------------------------------------------- /examples/transformers_usage.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | from rewardanything.processing import prepare_chat_messages, parse_rewardanything_output 3 | import torch 4 | 5 | # Load model and tokenizer directly 6 | model = AutoModelForCausalLM.from_pretrained( 7 | "zhuohaoyu/RewardAnything-8B-v1", 8 | torch_dtype="auto", 9 | device_map="auto" 10 | ) 11 | tokenizer = AutoTokenizer.from_pretrained("zhuohaoyu/RewardAnything-8B-v1") 12 | 13 | # Prepare evaluation data 14 | principle = "Judge responses based on helpfulness and accuracy" 15 | prompt = "What is the capital of France?" 16 | responses = { 17 | "model_a": "Paris is the capital of France.", 18 | "model_b": "I think it might be Lyon or Paris." 19 | } 20 | 21 | # Prepare chat messages (handles masking automatically) 22 | messages, masked2real = prepare_chat_messages(principle, prompt, responses) 23 | 24 | # Format with chat template 25 | formatted_input = tokenizer.apply_chat_template( 26 | messages, tokenize=False, add_generation_prompt=True 27 | ) 28 | 29 | # Generate response 30 | inputs = tokenizer(formatted_input, return_tensors="pt").to(model.device) 31 | with torch.no_grad(): 32 | outputs = model.generate( 33 | **inputs, 34 | max_new_tokens=4096, 35 | temperature=0.1, 36 | do_sample=True, 37 | pad_token_id=tokenizer.eos_token_id 38 | ) 39 | 40 | # Parse structured results (handles JSON parsing robustly) 41 | output_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True) 42 | result = parse_rewardanything_output(output_text, masked2real) 43 | 44 | print(f"Parsed scores: {result.scores}") 45 | print(f"Ranking: {result.ranking}") 46 | print(f"Reasoning: {result.reasoning}") -------------------------------------------------------------------------------- /rewardanything/cli.py: -------------------------------------------------------------------------------- 1 | """Command-line interface for RewardAnything.""" 2 | 3 | import argparse 4 | import sys 5 | 6 | 7 | def main(): 8 | """Main CLI entry point.""" 9 | parser = argparse.ArgumentParser( 10 | description="RewardAnything CLI", 11 | formatter_class=argparse.RawDescriptionHelpFormatter 12 | ) 13 | 14 | subparsers = parser.add_subparsers(dest='command', help='Available commands') 15 | 16 | # Serve command 17 | serve_parser = subparsers.add_parser('serve', help='Start RewardAnything server') 18 | serve_parser.add_argument("-c", "--config", required=True, help="Path to configuration file") 19 | serve_parser.add_argument("--port", type=int, default=8000, help="Port to listen on") 20 | serve_parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") 21 | serve_parser.add_argument("--base-output-path", default="./outputs", 22 | help="Base directory for storing batch outputs") 23 | 24 | # Parse arguments 25 | args = parser.parse_args() 26 | 27 | if args.command == 'serve': 28 | # Set up arguments for serve module 29 | serve_args = [ 30 | '--config', args.config, 31 | 
'--port', str(args.port), 32 | '--host', args.host, 33 | '--base-output-path', args.base_output_path 34 | ] 35 | 36 | # Replace sys.argv with serve arguments 37 | original_argv = sys.argv.copy() 38 | sys.argv = ['rewardanything-serve'] + serve_args 39 | 40 | try: 41 | from .serve import main as serve_main 42 | serve_main() 43 | finally: 44 | # Restore original argv 45 | sys.argv = original_argv 46 | else: 47 | parser.print_help() 48 | 49 | 50 | if __name__ == "__main__": 51 | main() -------------------------------------------------------------------------------- /examples/local_usage.py: -------------------------------------------------------------------------------- 1 | import rewardanything 2 | 3 | # Load model locally (similar to HuggingFace) 4 | reward_model = rewardanything.from_pretrained( 5 | "zhuohaoyu/RewardAnything-8B-v1", # Model path/name 6 | device="cuda", # Device placement 7 | torch_dtype="auto" # Automatic dtype selection 8 | ) 9 | 10 | # Define your evaluation principle 11 | principle = "I prefer clear, concise and helpful responses over long and detailed ones." 12 | 13 | # Your evaluation data 14 | prompt = "How do I learn Python programming effectively?" 15 | responses = { 16 | "response_a": "Start with Python.org's tutorial, practice daily with small projects, and join r/learnpython for help. Focus on fundamentals first.", 17 | "response_b": "Here's a comprehensive approach: 1) Start with Python basics including variables, data types, operators, control structures like if-statements, for-loops, while-loops, and functions, 2) Practice with small projects like calculators, text games, and data manipulation scripts, 3) Use interactive platforms like Codecademy, Python.org's official tutorial, edX courses, Coursera specializations, and YouTube channels, 4) Join communities like r/learnpython, Stack Overflow, Python Discord servers, and local meetups for support and networking, 5) Build progressively complex projects including web scrapers, APIs, data analysis tools, and web applications, 6) Read books like 'Automate the Boring Stuff', 'Python Crash Course', and 'Effective Python', 7) Dedicate 1-2 hours daily for consistent progress and track your learning journey.", 18 | "response_c": "Learn Python by coding." 
19 | } 20 | 21 | # Get comprehensive evaluation 22 | result = reward_model.judge( 23 | principle=principle, 24 | prompt=prompt, 25 | responses=responses 26 | ) 27 | 28 | print(f"Scores: {result.scores}") 29 | print(f"Best to worst: {result.ranking}") 30 | print(f"Reasoning: {result.reasoning}") -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | share/python-wheels/ 20 | *.egg-info/ 21 | .installed.cfg 22 | *.egg 23 | MANIFEST 24 | 25 | # PyInstaller 26 | *.manifest 27 | *.spec 28 | 29 | # Installer logs 30 | pip-log.txt 31 | 32 | # Unit test / coverage reports 33 | htmlcov/ 34 | .tox/ 35 | .nox/ 36 | .coverage 37 | .coverage.* 38 | .cache 39 | nosetests.xml 40 | coverage.xml 41 | *.cover 42 | *.py,cover 43 | .hypothesis/ 44 | .pytest_cache/ 45 | cover/ 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Jupyter Notebook 52 | .ipynb_checkpoints 53 | 54 | # IPython 55 | profile_default/ 56 | ipython_config.py 57 | 58 | # pyenv 59 | .python-version 60 | 61 | # pipenv 62 | Pipfile.lock 63 | 64 | # poetry 65 | poetry.lock 66 | 67 | # pdm 68 | .pdm.toml 69 | 70 | # PEP 582 71 | __pypackages__/ 72 | 73 | # Environments 74 | .env 75 | .venv 76 | env/ 77 | venv/ 78 | ENV/ 79 | env.bak/ 80 | venv.bak/ 81 | 82 | # Spyder project settings 83 | .spyderproject 84 | .spyproject 85 | 86 | # Rope project settings 87 | .ropeproject 88 | 89 | # mkdocs documentation 90 | /site 91 | 92 | # mypy 93 | .mypy_cache/ 94 | .dmypy.json 95 | dmypy.json 96 | 97 | # Pyre type checker 98 | .pyre/ 99 | 100 | # pytype static type analyzer 101 | .pytype/ 102 | 103 | # Cython debug symbols 104 | cython_debug/ 105 | 106 | # RewardAnything specific 107 | # Output directories 108 | outputs/ 109 | tmp/ 110 | temp/ 111 | logs/ 112 | cache/ 113 | 114 | # Model files and checkpoints 115 | *.bin 116 | *.safetensors 117 | checkpoints/ 118 | models/ 119 | converted_ckpts/ 120 | *.ckpt 121 | *.pth 122 | *.pt 123 | 124 | # Config files with sensitive data 125 | config_real.json 126 | config_production.json 127 | *_real.json 128 | *_prod.json 129 | .secrets/ 130 | 131 | # Server outputs and batch processing 132 | responses/ 133 | all_responses.jsonl 134 | single_*/ 135 | batch_*/ 136 | 137 | # Transformers cache 138 | .cache/ 139 | transformers_cache/ 140 | 141 | # HuggingFace cache 142 | .huggingface/ 143 | 144 | # IDE and editor files 145 | .vscode/ 146 | .idea/ 147 | *.swp 148 | *.swo 149 | *~ 150 | 151 | # Linux 152 | *~ 153 | 154 | # Temporary files 155 | *.tmp 156 | *.temp 157 | *.bak 158 | *.backup 159 | 160 | # Compiled files 161 | *.pyc 162 | *.pyo 163 | *.pyd 164 | 165 | 166 | # Development and testing 167 | .tox/ 168 | .coverage 169 | htmlcov/ 170 | .pytest_cache/ 171 | test_outputs/ 172 | benchmark_results/ 173 | 174 | # Documentation builds 175 | docs/build/ 176 | docs/_build/ 177 | 178 | # Package builds 179 | dist/ 180 | build/ 181 | *.egg-info/ 182 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = 
"rewardanything" 7 | version = "1.0.1" 8 | description = "RewardAnything: Generalizable Principle-Following Reward Models" 9 | readme = "README.md" 10 | license = {text = "Apache-2.0"} 11 | authors = [ 12 | {name = "Zhuohao Yu", email = "zhuohaoyu1228@gmail.com"}, 13 | {name = "Jiali Zeng"}, 14 | {name = "Weizheng Gu"}, 15 | {name = "Yidong Wang"}, 16 | {name = "Jindong Wang"}, 17 | {name = "Fandong Meng"}, 18 | {name = "Jie Zhou"}, 19 | {name = "Yue Zhang"}, 20 | {name = "Shikun Zhang"}, 21 | {name = "Wei Ye"} 22 | ] 23 | maintainers = [ 24 | {name = "Zhuohao Yu", email = "zhuohaoyu1228@gmail.com"} 25 | ] 26 | keywords = ["machine learning", "reward modeling", "RLHF", "principle-following", "evaluation", "LLM", "alignment"] 27 | classifiers = [ 28 | "Development Status :: 4 - Beta", 29 | "Intended Audience :: Developers", 30 | "Intended Audience :: Science/Research", 31 | "License :: OSI Approved :: Apache Software License", 32 | "Operating System :: OS Independent", 33 | "Programming Language :: Python :: 3", 34 | "Programming Language :: Python :: 3.8", 35 | "Programming Language :: Python :: 3.9", 36 | "Programming Language :: Python :: 3.10", 37 | "Programming Language :: Python :: 3.11", 38 | "Programming Language :: Python :: 3.12", 39 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 40 | "Topic :: Software Development :: Libraries :: Python Modules", 41 | ] 42 | requires-python = ">=3.8" 43 | dependencies = [ 44 | "torch>=2.0.0", 45 | "transformers>=4.51.0", 46 | "fastapi>=0.104.0", 47 | "uvicorn>=0.24.0", 48 | "pydantic>=2.0.0", 49 | "requests>=2.28.0", 50 | "numpy>=1.21.0", 51 | "scipy>=1.7.0", 52 | "tqdm>=4.64.0", 53 | "openai>=1.0.0", 54 | ] 55 | 56 | [project.optional-dependencies] 57 | server = [ 58 | "fastapi>=0.104.0", 59 | "uvicorn>=0.24.0", 60 | ] 61 | local = [ 62 | "torch>=2.0.0", 63 | "transformers>=4.51.0", 64 | ] 65 | all = [ 66 | "fastapi>=0.104.0", 67 | "uvicorn>=0.24.0", 68 | "torch>=2.0.0", 69 | "transformers>=4.51.0", 70 | ] 71 | dev = [ 72 | "pytest>=7.0.0", 73 | "pytest-asyncio>=0.21.0", 74 | "black>=23.0.0", 75 | "isort>=5.12.0", 76 | "flake8>=6.0.0", 77 | "mypy>=1.0.0", 78 | ] 79 | 80 | [project.urls] 81 | Homepage = "https://github.com/zhuohaoyu/RewardAnything" 82 | Repository = "https://github.com/zhuohaoyu/RewardAnything" 83 | Documentation = "https://github.com/zhuohaoyu/RewardAnything#readme" 84 | "Bug Tracker" = "https://github.com/zhuohaoyu/RewardAnything/issues" 85 | 86 | [project.scripts] 87 | rewardanything = "rewardanything.cli:main" 88 | 89 | [tool.setuptools.packages.find] 90 | include = ["rewardanything*"] 91 | 92 | [tool.setuptools.package-data] 93 | rewardanything = ["*.json", "*.yaml", "*.yml"] 94 | 95 | [tool.black] 96 | line-length = 100 97 | target-version = ['py38'] 98 | 99 | [tool.isort] 100 | profile = "black" 101 | line_length = 100 102 | 103 | [tool.mypy] 104 | python_version = "3.8" 105 | warn_return_any = true 106 | warn_unused_configs = true 107 | disallow_untyped_defs = true -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | # Read README for long description 5 | with open("README.md", "r", encoding="utf-8") as fh: 6 | long_description = fh.read() 7 | 8 | # Read requirements 9 | def read_requirements(filename): 10 | if os.path.exists(filename): 11 | with open(filename, "r", encoding="utf-8") as f: 12 | return [line.strip() for 
line in f if line.strip() and not line.startswith("#")] 13 | return [] 14 | 15 | # Core requirements for the package 16 | core_requirements = [ 17 | "torch>=2.0.0", 18 | "transformers>=4.51.0", 19 | "tokenizers>=0.13.0", 20 | "requests>=2.25.0", 21 | "pydantic>=1.8.0", 22 | "tqdm>=4.62.0", 23 | "numpy>=1.21.0", 24 | "accelerate>=1.7.0", 25 | ] 26 | 27 | # Server requirements 28 | server_requirements = [ 29 | "fastapi>=0.68.0", 30 | "uvicorn[standard]>=0.15.0", 31 | "httpx>=0.24.0", 32 | "openai>=1.0.0", 33 | "asyncio", 34 | ] 35 | 36 | # Development requirements 37 | dev_requirements = [ 38 | "pytest>=6.0.0", 39 | "pytest-asyncio>=0.18.0", 40 | "black>=21.0.0", 41 | "isort>=5.9.0", 42 | "flake8>=3.9.0", 43 | "mypy>=0.910", 44 | "pre-commit>=2.15.0", 45 | ] 46 | 47 | # Benchmark requirements 48 | benchmark_requirements = [ 49 | "datasets>=2.0.0", 50 | "scipy>=1.7.0", 51 | "pandas>=1.3.0", 52 | "scikit-learn>=1.0.0", 53 | ] 54 | 55 | setup( 56 | name="RewardAnything", 57 | version="1.0.1", 58 | author="Zhuohao Yu", 59 | author_email="zyu@stu.pku.edu.cn", 60 | description="RewardAnything: Generalizable Principle-Following Reward Models", 61 | long_description=long_description, 62 | long_description_content_type="text/markdown", 63 | url="https://github.com/zhuohaoyu/RewardAnything", 64 | packages=find_packages(), 65 | classifiers=[ 66 | "Development Status :: 4 - Beta", 67 | "Intended Audience :: Developers", 68 | "Intended Audience :: Science/Research", 69 | "License :: OSI Approved :: Apache Software License", 70 | "Operating System :: OS Independent", 71 | "Programming Language :: Python :: 3", 72 | "Programming Language :: Python :: 3.8", 73 | "Programming Language :: Python :: 3.9", 74 | "Programming Language :: Python :: 3.10", 75 | "Programming Language :: Python :: 3.11", 76 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 77 | "Topic :: Software Development :: Libraries :: Python Modules", 78 | ], 79 | python_requires=">=3.8", 80 | install_requires=core_requirements, 81 | extras_require={ 82 | "server": server_requirements, 83 | "dev": dev_requirements, 84 | "benchmarks": benchmark_requirements, 85 | "all": server_requirements + dev_requirements + benchmark_requirements, 86 | }, 87 | entry_points={ 88 | "console_scripts": [ 89 | "rewardanything=rewardanything.cli:main", 90 | "rewardanything-serve=rewardanything.serve:main", 91 | ], 92 | }, 93 | include_package_data=True, 94 | package_data={ 95 | "rewardanything": ["*.json", "*.yaml", "*.txt"], 96 | }, 97 | keywords="reward model, RLHF, language model, evaluation, principle-following", 98 | project_urls={ 99 | "Bug Reports": "https://github.com/zhuohaoyu/RewardAnything/issues", 100 | "Source": "https://github.com/zhuohaoyu/RewardAnything", 101 | "Documentation": "https://rewardanything.readthedocs.io/", 102 | "Paper": "https://arxiv.org/abs/2506.03637", 103 | }, 104 | ) -------------------------------------------------------------------------------- /rewardanything/local.py: -------------------------------------------------------------------------------- 1 | """Local inference implementation for RewardAnything models.""" 2 | 3 | import torch 4 | from typing import Dict, List, Optional, Union, Any 5 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 6 | from .models import RewardResult 7 | from .processing import prepare_chat_messages, parse_rewardanything_output 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class RewardModel: 14 | """Local RewardAnything 
model for principle-following evaluation.""" 15 | 16 | def __init__( 17 | self, 18 | model, 19 | tokenizer, 20 | generation_config: Optional[GenerationConfig] = None, 21 | device: Optional[str] = None 22 | ): 23 | self.model = model 24 | self.tokenizer = tokenizer 25 | self.generation_config = generation_config or GenerationConfig( 26 | max_new_tokens=4096, 27 | temperature=0.1, 28 | do_sample=True, 29 | top_p=0.9, 30 | pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 0 31 | ) 32 | self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") 33 | 34 | def judge( 35 | self, 36 | principle: str, 37 | prompt: str, 38 | responses: Dict[str, str], 39 | mask_responses: bool = True, 40 | **generation_kwargs 41 | ) -> RewardResult: 42 | """ 43 | Evaluate responses based on a natural language principle. 44 | 45 | Args: 46 | principle: Natural language principle for evaluation 47 | prompt: The input prompt that responses are answering 48 | responses: Dict mapping model names to their responses 49 | mask_responses: Whether to mask model names during evaluation 50 | **generation_kwargs: Additional generation parameters 51 | 52 | Returns: 53 | RewardResult containing scores, ranking, and reasoning 54 | """ 55 | # Prepare chat messages using unified processing 56 | messages, masked2real = prepare_chat_messages( 57 | principle, prompt, responses, mask_responses 58 | ) 59 | 60 | # Format for the model 61 | formatted_input = self.tokenizer.apply_chat_template( 62 | messages, tokenize=False, add_generation_prompt=True 63 | ) 64 | 65 | # Tokenize input 66 | inputs = self.tokenizer( 67 | formatted_input, 68 | return_tensors="pt", 69 | padding=True, 70 | truncation=True, 71 | max_length=4096 72 | ).to(self.device) 73 | 74 | # Generate response 75 | generation_config = self.generation_config 76 | if generation_kwargs: 77 | generation_config = GenerationConfig(**{ 78 | **self.generation_config.to_dict(), 79 | **generation_kwargs 80 | }) 81 | 82 | with torch.no_grad(): 83 | outputs = self.model.generate( 84 | **inputs, 85 | generation_config=generation_config, 86 | pad_token_id=self.tokenizer.eos_token_id 87 | ) 88 | 89 | # Decode output 90 | generated_tokens = outputs[0][inputs.input_ids.shape[1]:] 91 | output_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) 92 | 93 | # Parse output using unified processing 94 | return parse_rewardanything_output(output_text, masked2real) 95 | 96 | def judge_batch( 97 | self, 98 | requests: List[Dict[str, Any]], 99 | batch_size: int = 8, 100 | **generation_kwargs 101 | ) -> List[RewardResult]: 102 | """ 103 | Evaluate multiple requests in batches. 
104 | 105 | Args: 106 | requests: List of dicts with 'principle', 'prompt', 'responses' keys 107 | batch_size: Batch size for processing 108 | **generation_kwargs: Additional generation parameters 109 | 110 | Returns: 111 | List of RewardResult objects 112 | """ 113 | results = [] 114 | for i in range(0, len(requests), batch_size): 115 | batch = requests[i:i + batch_size] 116 | for request in batch: 117 | result = self.judge( 118 | principle=request["principle"], 119 | prompt=request["prompt"], 120 | responses=request["responses"], 121 | mask_responses=request.get("mask_responses", True), 122 | **generation_kwargs 123 | ) 124 | results.append(result) 125 | return results 126 | 127 | 128 | def from_pretrained( 129 | model_name_or_path: str, 130 | device: Optional[str] = None, 131 | torch_dtype: Optional[Union[str, torch.dtype]] = None, 132 | trust_remote_code: bool = False, 133 | generation_config: Optional[Dict[str, Any]] = None, 134 | **kwargs 135 | ) -> RewardModel: 136 | """ 137 | Load a RewardAnything model for local inference. 138 | 139 | Args: 140 | model_name_or_path: Path to model or HuggingFace model identifier 141 | device: Device to load model on ('cuda', 'cpu', 'auto') 142 | torch_dtype: Data type for model weights 143 | trust_remote_code: Whether to trust remote code 144 | generation_config: Generation configuration parameters 145 | **kwargs: Additional arguments passed to AutoModelForCausalLM.from_pretrained 146 | 147 | Returns: 148 | RewardModel instance ready for evaluation 149 | """ 150 | if device is None: 151 | device = "cuda" if torch.cuda.is_available() else "cpu" 152 | 153 | if torch_dtype == "auto": 154 | torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32 155 | elif isinstance(torch_dtype, str): 156 | torch_dtype = getattr(torch, torch_dtype) 157 | 158 | # Load tokenizer 159 | tokenizer = AutoTokenizer.from_pretrained( 160 | model_name_or_path, 161 | trust_remote_code=trust_remote_code, 162 | **{k: v for k, v in kwargs.items() if k in ['use_fast', 'padding_side']} 163 | ) 164 | 165 | if tokenizer.pad_token is None: 166 | tokenizer.pad_token = tokenizer.eos_token 167 | 168 | # Load model 169 | model = AutoModelForCausalLM.from_pretrained( 170 | model_name_or_path, 171 | torch_dtype=torch_dtype, 172 | device_map=device if device != "auto" else "auto", 173 | trust_remote_code=trust_remote_code, 174 | **{k: v for k, v in kwargs.items() if k not in ['use_fast', 'padding_side']} 175 | ) 176 | 177 | # Create generation config 178 | gen_config = None 179 | if generation_config: 180 | gen_config = GenerationConfig(**generation_config) 181 | 182 | return RewardModel( 183 | model=model, 184 | tokenizer=tokenizer, 185 | generation_config=gen_config, 186 | device=device 187 | ) -------------------------------------------------------------------------------- /rewardanything/client.py: -------------------------------------------------------------------------------- 1 | """Remote client implementation for RewardAnything.""" 2 | 3 | import json 4 | import time 5 | import requests 6 | from typing import Dict, List, Optional, Any, Union 7 | from .models import RewardResult, RewardRequest, RewardResponse 8 | import logging 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class Client: 14 | """Remote client for RewardAnything API.""" 15 | 16 | def __init__( 17 | self, 18 | base_url: str, 19 | api_key: Optional[str] = None, 20 | timeout: float = 30.0, 21 | max_retries: int = 3, 22 | headers: Optional[Dict[str, str]] = None, 23 | **kwargs 24 | ): 25 | 
""" 26 | Initialize RewardAnything client. 27 | 28 | Args: 29 | base_url: Base URL for the RewardAnything API 30 | api_key: Optional API key for authentication 31 | timeout: Request timeout in seconds 32 | max_retries: Maximum number of retry attempts 33 | headers: Additional headers to include in requests 34 | **kwargs: Additional client configuration 35 | """ 36 | self.base_url = base_url.rstrip('/') 37 | self.api_key = api_key 38 | self.timeout = timeout 39 | self.max_retries = max_retries 40 | 41 | self.headers = { 42 | "Content-Type": "application/json", 43 | "User-Agent": "RewardAnything-Python-Client/1.0.1" 44 | } 45 | 46 | if api_key: 47 | self.headers["Authorization"] = f"Bearer {api_key}" 48 | 49 | if headers: 50 | self.headers.update(headers) 51 | 52 | # Store additional config 53 | self.config = kwargs 54 | 55 | def _make_request( 56 | self, 57 | endpoint: str, 58 | data: Dict[str, Any], 59 | timeout: Optional[float] = None 60 | ) -> Dict[str, Any]: 61 | """Make HTTP request with retries.""" 62 | url = f"{self.base_url}{endpoint}" 63 | timeout = timeout or self.timeout 64 | 65 | last_exception = None 66 | for attempt in range(self.max_retries + 1): 67 | try: 68 | response = requests.post( 69 | url, 70 | json=data, 71 | headers=self.headers, 72 | timeout=timeout 73 | ) 74 | response.raise_for_status() 75 | return response.json() 76 | 77 | except requests.exceptions.RequestException as e: 78 | last_exception = e 79 | if attempt < self.max_retries: 80 | wait_time = 2 ** attempt # Exponential backoff 81 | logger.warning(f"Request failed (attempt {attempt + 1}), retrying in {wait_time}s: {e}") 82 | time.sleep(wait_time) 83 | else: 84 | logger.error(f"Request failed after {self.max_retries + 1} attempts: {e}") 85 | 86 | raise last_exception 87 | 88 | def judge( 89 | self, 90 | principle: str, 91 | prompt: str, 92 | responses: Dict[str, str], 93 | mask_responses: bool = True, 94 | timeout: Optional[float] = None, 95 | **kwargs 96 | ) -> RewardResult: 97 | """ 98 | Evaluate responses based on a natural language principle. 99 | 100 | Args: 101 | principle: Natural language principle for evaluation 102 | prompt: The input prompt that responses are answering 103 | responses: Dict mapping model names to their responses 104 | mask_responses: Whether to mask model names during evaluation 105 | timeout: Request timeout override 106 | **kwargs: Additional request parameters 107 | 108 | Returns: 109 | RewardResult containing scores, ranking, and reasoning 110 | """ 111 | request_data = { 112 | "principle": principle, 113 | "prompt": prompt, 114 | "responses": responses, 115 | "mask_responses": mask_responses 116 | } 117 | 118 | # Add any additional config 119 | request_data.update(kwargs) 120 | 121 | try: 122 | response_data = self._make_request( 123 | "/api/rewardanything", 124 | request_data, 125 | timeout 126 | ) 127 | 128 | # Parse response 129 | thoughts = response_data.get("thoughts", "") 130 | results = response_data.get("results", {}) 131 | 132 | return RewardResult( 133 | reasoning=thoughts, 134 | scores=results.get("scores", {}), 135 | ranking=results.get("best-to-worst", []), 136 | raw_output=json.dumps(response_data) 137 | ) 138 | 139 | except Exception as e: 140 | logger.error(f"Failed to evaluate with principle '{principle}': {e}") 141 | raise 142 | 143 | def judge_batch( 144 | self, 145 | requests: List[Dict[str, Any]], 146 | timeout: Optional[float] = None, 147 | **kwargs 148 | ) -> List[RewardResult]: 149 | """ 150 | Evaluate multiple requests in a batch. 
151 | 152 | Args: 153 | requests: List of dicts with 'principle', 'prompt', 'responses' keys 154 | timeout: Request timeout override 155 | **kwargs: Additional request parameters 156 | 157 | Returns: 158 | List of RewardResult objects 159 | """ 160 | # Convert to RewardRequest format 161 | batch_requests = [] 162 | for req in requests: 163 | batch_requests.append({ 164 | "principle": req["principle"], 165 | "prompt": req["prompt"], 166 | "responses": req["responses"], 167 | "mask_responses": req.get("mask_responses", True) 168 | }) 169 | 170 | try: 171 | response_data = self._make_request( 172 | "/api/rewardanything_batch", 173 | batch_requests, 174 | timeout or (self.timeout * len(requests)) # Scale timeout with batch size 175 | ) 176 | 177 | # Parse batch response 178 | results = [] 179 | for item in response_data: 180 | results.append(RewardResult( 181 | reasoning=item.get("thoughts", ""), 182 | scores=item.get("results", {}).get("scores", {}), 183 | ranking=item.get("results", {}).get("best-to-worst", []), 184 | raw_output=json.dumps(item) 185 | )) 186 | 187 | return results 188 | 189 | except Exception as e: 190 | logger.error(f"Failed to evaluate batch of {len(requests)} requests: {e}") 191 | raise 192 | 193 | def health_check(self) -> bool: 194 | """Check if the server is healthy.""" 195 | try: 196 | response = requests.get( 197 | f"{self.base_url}/health", 198 | headers=self.headers, 199 | timeout=5.0 200 | ) 201 | return response.status_code == 200 202 | except: 203 | return False -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity granting the License. 13 | 14 | "Legal Entity" shall mean the union of the acting entity and all 15 | other entities that control, are controlled by, or are under common 16 | control with that entity. For the purposes of this definition, 17 | "control" means (i) the power, direct or indirect, to cause the 18 | direction or management of such entity, whether by contract or 19 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 20 | outstanding shares, or (iii) beneficial ownership of such entity. 21 | 22 | "You" (or "Your") shall mean an individual or Legal Entity 23 | exercising permissions granted by this License. 24 | 25 | "Source" form shall mean the preferred form for making modifications, 26 | including but not limited to software source code, documentation 27 | source, and configuration files. 28 | 29 | "Object" form shall mean any form resulting from mechanical 30 | transformation or translation of a Source form, including but 31 | not limited to compiled object code, generated documentation, 32 | and conversions to other media types. 33 | 34 | "Work" shall mean the work of authorship, whether in Source or 35 | Object form, made available under the License, as indicated by a 36 | copyright notice that is included in or attached to the work 37 | (which shall not include communications that are clearly marked or 38 | otherwise designated in writing by the owner as "Not a Work of the License"). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based upon (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and derivative works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control 57 | systems, and issue tracking systems that are managed by, or on behalf 58 | of, the Licensor for the purpose of discussing and improving the Work, 59 | but excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to use, reproduce, modify, distribute, and prepare 70 | Derivative Works of, and to display and perform the Work and such Derivative 71 | Works in any medium or format, whether now known or hereafter devised, 72 | provided that You preserve all copyright, notice, and attribution 73 | statements. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 84 | 85 | 4. Redistribution. 
You must give any other recipients of the Work or 86 | Derivative Works a copy of this License; and You must cause any 87 | modified files to carry prominent notices stating that You changed 88 | the files; and You must retain, in the Source form of any Derivative 89 | Works that You distribute, all copyright, trademark, patent, 90 | attribution and other notices from the Source form of the Work, 91 | excluding those notices that do not pertain to any part of 92 | the Derivative Works; and If the Work includes a "NOTICE" text file 93 | as part of its distribution, then any Derivative Works that You 94 | distribute must include a readable copy of the attribution notices 95 | contained within such NOTICE file, excluding those notices that do not 96 | pertain to any part of the Derivative Works, in at least one of the 97 | following places: within a NOTICE text file distributed as part of the 98 | Derivative Works; within the Source form or documentation, if provided 99 | along with the Derivative Works; or, within a display generated by the 100 | Derivative Works, if and wherever such third-party notices normally appear. 101 | 102 | 5. Submission of Contributions. Unless You explicitly state otherwise, 103 | any Contribution intentionally submitted for inclusion in the Work 104 | by You to the Licensor shall be under the terms and conditions of 105 | this License, without any additional terms or conditions. 106 | Notwithstanding the above, nothing herein shall supersede or modify 107 | the terms of any separate license agreement you may have executed 108 | with Licensor regarding such Contributions. 109 | 110 | 6. Trademarks. This License does not grant permission to use the trade 111 | names, trademarks, service marks, or product names of the Licensor, 112 | except as required for reasonable and customary use in describing the 113 | origin of the Work and reproducing the content of the NOTICE file. 114 | 115 | 7. Disclaimer of Warranty. Unless required by applicable law or 116 | agreed to in writing, Licensor provides the Work (and each 117 | Contributor provides its Contributions) on an "AS IS" BASIS, 118 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 119 | implied, including, without limitation, any warranties or conditions 120 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 121 | PARTICULAR PURPOSE. You are solely responsible for determining the 122 | appropriateness of using or redistributing the Work and assume any 123 | risks associated with Your exercise of permissions under this License. 124 | 125 | 8. Limitation of Liability. In no event and under no legal theory, 126 | whether in tort (including negligence), contract, or otherwise, 127 | unless required by applicable law (such as deliberate and grossly 128 | negligent acts) or agreed to in writing, shall any Contributor be 129 | liable to You for damages, including any direct, indirect, special, 130 | incidental, or consequential damages of any character arising as a 131 | result of this License or out of the use or inability to use the 132 | Work (including but not limited to damages for loss of goodwill, 133 | work stoppage, computer failure or malfunction, or any and all 134 | other commercial damages or losses), even if such Contributor 135 | has been advised of the possibility of such damages. 136 | 137 | 9. Accepting Warranty or Support. 
When redistributing the Work or 138 | Derivative Works thereof, You may choose to offer, and charge a fee 139 | for, acceptance of support, warranty, indemnity, or other liability 140 | obligations and/or rights consistent with this License. However, in 141 | accepting such obligations, You may act only on Your own behalf and on 142 | Your sole responsibility, not on behalf of any other Contributor, and 143 | only if You agree to indemnify, defend, and hold each Contributor 144 | harmless for any liability incurred by, or claims asserted against, 145 | such Contributor by reason of your accepting any such warranty or support. 146 | 147 | END OF TERMS AND CONDITIONS 148 | 149 | Copyright 2025 RewardAnything Contributors 150 | 151 | Licensed under the Apache License, Version 2.0 (the "License"); 152 | you may not use this file except in compliance with the License. 153 | You may obtain a copy of the License at 154 | 155 | http://www.apache.org/licenses/LICENSE-2.0 156 | 157 | Unless required by applicable law or agreed to in writing, software 158 | distributed under the License is distributed on an "AS IS" BASIS, 159 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 160 | See the License for the specific language governing permissions and 161 | limitations under the License. -------------------------------------------------------------------------------- /rewardanything/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import os 4 | import argparse 5 | import logging 6 | import time 7 | import codecs 8 | import traceback 9 | 10 | from typing import Optional, List, Dict, Union, Any 11 | from random import choice 12 | from tqdm.asyncio import tqdm as atqdm 13 | from openai import AsyncOpenAI, APIStatusError 14 | 15 | 16 | class OpenAIClient: 17 | def __init__( 18 | self, 19 | output_path: str, 20 | api_model: str, 21 | api_key: Union[str, List[str]], 22 | api_base: Optional[Union[str, List[str]]] = None, 23 | api_proxy: Optional[Union[str, List[str]]] = None, 24 | api_timeout: Optional[float] = 60.0, 25 | api_max_retries: Optional[int] = 5, 26 | generation_config: Optional[Dict] = None, 27 | max_error_count: Optional[int] = 100, 28 | trial_run=False, 29 | dump_individual_rsp=True, 30 | ): 31 | self.output_path = output_path 32 | self.trial_run = trial_run 33 | self.max_error_count = max_error_count 34 | self.total_errors = 0 35 | self.logger = logging.getLogger(__name__) 36 | 37 | if isinstance(api_key, str): 38 | api_key = [api_key] 39 | 40 | if api_base is None: 41 | api_base = ["https://api.openai.com/v1"] * len(api_key) 42 | elif isinstance(api_base, str): 43 | api_base = [api_base] * len(api_key) 44 | 45 | assert len(api_key) == len( 46 | api_base 47 | ), "Number of api_key and api_base must match" 48 | 49 | if api_proxy is not None: 50 | if isinstance(api_proxy, str): 51 | api_proxy = [api_proxy] * len(api_key) 52 | else: 53 | assert len(api_proxy) == len( 54 | api_key 55 | ), "Number of api_key and api_proxy must match" 56 | self.clients = [ 57 | AsyncOpenAI( 58 | api_key=key, 59 | base_url=api_base, 60 | timeout=api_timeout, 61 | max_retries=api_max_retries, 62 | ) 63 | for key, api_base, proxy in zip( 64 | api_key, api_base, api_proxy 65 | ) 66 | ] 67 | else: 68 | self.clients = [ 69 | AsyncOpenAI( 70 | api_key=key, 71 | base_url=api_base, 72 | timeout=api_timeout, 73 | max_retries=api_max_retries, 74 | ) 75 | for key, api_base in zip(api_key, api_base) 76 | ] 77 | 78 | self.model 
= api_model 79 | 80 | self.response_queue = asyncio.Queue() 81 | self.dump_individual_rsp = dump_individual_rsp 82 | 83 | if generation_config is None: 84 | self.generation_config = { 85 | "frequency_penalty": 0, 86 | "max_tokens": 100, 87 | "n": 1, 88 | "presence_penalty": 0, 89 | "response_format": {"type": "text"}, 90 | "seed": 42, 91 | "stream": False, 92 | "temperature": 0.0, 93 | } 94 | else: 95 | self.generation_config = generation_config 96 | 97 | if dump_individual_rsp: 98 | os.makedirs(os.path.join(self.output_path, "responses"), exist_ok=True) 99 | 100 | async def query( 101 | self, 102 | request, 103 | num_retries=3, 104 | ): 105 | if isinstance(request, dict): 106 | request_dict = request 107 | else: 108 | request_dict = request.__dict__ 109 | 110 | assert "messages" in request_dict, "messages must be provided in request" 111 | assert "uuid" in request_dict, "uuid must be provided in request" 112 | 113 | if self.dump_individual_rsp: 114 | save_path = os.path.join( 115 | self.output_path, "responses", f'{request_dict["uuid"]}.json' 116 | ) 117 | 118 | if os.path.exists(save_path) and not self.trial_run: 119 | with codecs.open(save_path) as f: 120 | rsp_content = json.load(f) 121 | await self.response_queue.put(rsp_content) 122 | return f"Skipping {save_path}" 123 | 124 | if "generation_config" in request_dict and isinstance( 125 | request_dict["generation_config"], dict 126 | ): 127 | generation_config = self.generation_config.copy() 128 | generation_config.update(request_dict["generation_config"]) 129 | else: 130 | generation_config = self.generation_config 131 | 132 | response = None 133 | while num_retries > 0: 134 | num_retries -= 1 135 | try: 136 | client = choice(self.clients) 137 | response = await client.chat.completions.create( 138 | messages=request_dict["messages"], 139 | model=self.model, 140 | **generation_config, 141 | ) 142 | response = response.model_dump() 143 | break 144 | except APIStatusError as e: 145 | if self.max_error_count > self.total_errors: 146 | self.total_errors += 1 147 | self.logger.warning( 148 | f"OpenAI APIStatusError: {e}, total errors: {self.total_errors}, sleeping..." 149 | ) 150 | await asyncio.sleep(1.0) 151 | else: 152 | self.logger.error( 153 | f"OpenAI APIStatusError: {e}, max_error_count reached, exiting..." 
154 | ) 155 | raise e 156 | except: 157 | print(traceback.format_exc()) 158 | 159 | if response is None: 160 | raise Exception("Empty response from remote OpenAI API") 161 | 162 | try: 163 | response["generated_text"] = response["choices"][0]["message"]["content"] 164 | except: 165 | print(traceback.format_exc()) 166 | print(response) 167 | raise Exception("Empty response from remote OpenAI API") 168 | 169 | if self.dump_individual_rsp: 170 | with codecs.open(save_path, "w") as f: 171 | json.dump( 172 | {"request": request_dict, "response": response}, 173 | f, 174 | ensure_ascii=False, 175 | indent=2, 176 | ) 177 | 178 | await self.response_queue.put({"request": request_dict, "response": response}) 179 | 180 | return response["choices"][0]["message"]["content"] 181 | 182 | async def write_responses_to_file(self): 183 | save_path = os.path.join(self.output_path, "all_responses.jsonl") 184 | while True: 185 | response = await self.response_queue.get() 186 | with codecs.open(save_path, "a", encoding="utf-8") as f: 187 | f.write(json.dumps(response, ensure_ascii=False) + "\n") 188 | self.response_queue.task_done() 189 | 190 | 191 | class AsyncRateLimitThreadPool: 192 | def __init__(self, num_workers, num_requests, period): 193 | self.num_workers = num_workers 194 | self.num_requests = num_requests 195 | self.loop = asyncio.get_event_loop() 196 | self.semaphore = asyncio.Semaphore(num_workers) 197 | self.last_call_time = time.time() 198 | self.call_count = 0 199 | self.period = period 200 | 201 | async def __aenter__(self): 202 | return self 203 | 204 | async def __aexit__(self, exc_type, exc, tb): 205 | pass 206 | 207 | async def _rate_limited_call(self, func, *args, **kwargs): 208 | # Limit the number of calls to func per minute 209 | elapsed_time = time.time() - self.last_call_time 210 | if elapsed_time < self.period: 211 | self.call_count += 1 212 | if self.call_count > self.num_requests: 213 | sleep_time = self.period - elapsed_time 214 | # logging.info("Sleeping for {} seconds".format(sleep_time)) 215 | await asyncio.sleep(sleep_time) 216 | self.call_count = 0 217 | self.last_call_time = time.time() 218 | 219 | # Acquire a semaphore permit before calling func 220 | async with self.semaphore: 221 | result = await func(*args, **kwargs) 222 | 223 | return result 224 | 225 | async def map(self, func, *args_list): 226 | coroutines = [self._rate_limited_call(func, *args) for args in zip(*args_list)] 227 | 228 | # Use tqdm progress bar with coroutines 229 | results = [] 230 | for coroutine in atqdm.as_completed(coroutines): 231 | result = await coroutine 232 | results.append(result) 233 | 234 | return results 235 | 236 | 237 | async def run_pool(api, requests, num_workers, num_requests, period): 238 | pool = AsyncRateLimitThreadPool(num_workers, num_requests, period) 239 | writer_task = asyncio.create_task(api.write_responses_to_file()) 240 | 241 | results = await pool.map(api.query, requests) 242 | await api.response_queue.join() # Ensure all responses are written 243 | writer_task.cancel() 244 | 245 | return results 246 | 247 | 248 | def run_api_inference( 249 | requests: Union[ 250 | List[Dict], Any 251 | ], # can List[Dict] or list of any object with __dict__ attribute 252 | output_path: str, # path to save responses 253 | api_model: str, # openai model name 254 | api_key: Union[str, List[str]], 255 | api_base: Optional[Union[str, List[str]]] = None, 256 | api_proxy: Optional[Union[str, List[str]]] = None, 257 | api_timeout: Optional[float] = 30.0, 258 | api_max_retries: Optional[int] = 5, 
259 | generation_config: Optional[Dict] = None, 260 | num_workers: Optional[int] = 8, 261 | request_limit: Optional[int] = 100, 262 | request_limit_period: Optional[int] = 60, 263 | max_error_count: Optional[int] = 100, 264 | trial_run=False, 265 | dump_individual_rsp=True, 266 | ): 267 | logging.getLogger(__name__).info( 268 | f"num_requests: {len(requests)}, output_path: {output_path}" 269 | ) 270 | logging.getLogger("httpx").setLevel(logging.WARNING) 271 | 272 | os.makedirs(output_path, exist_ok=True) 273 | 274 | if dump_individual_rsp: 275 | os.makedirs(os.path.join(output_path, "responses"), exist_ok=True) 276 | 277 | if os.path.exists(os.path.join(output_path, "all_responses.jsonl")): 278 | os.remove(os.path.join(output_path, "all_responses.jsonl")) 279 | 280 | client = OpenAIClient( 281 | output_path=output_path, 282 | api_model=api_model, 283 | api_key=api_key, 284 | api_base=api_base, 285 | api_proxy=api_proxy, 286 | api_timeout=api_timeout, 287 | api_max_retries=api_max_retries, 288 | generation_config=generation_config, 289 | trial_run=trial_run, 290 | dump_individual_rsp=dump_individual_rsp, 291 | ) 292 | 293 | try: 294 | asyncio.run( 295 | run_pool( 296 | client, 297 | requests, 298 | num_workers=num_workers, 299 | num_requests=request_limit, 300 | period=request_limit_period, 301 | ) 302 | ) 303 | except KeyboardInterrupt: 304 | logging.getLogger(__name__).info("Interrupt received! Closing...") 305 | -------------------------------------------------------------------------------- /docs/PROJECT_DOCS.md: -------------------------------------------------------------------------------- 1 | # RewardAnything Project Documentation 2 | 3 | ## Overview 4 | 5 | RewardAnything is a revolutionary reward modeling framework that enables models to understand and follow explicit natural language principles instead of learning implicit preferences from fixed datasets. This enables dynamic adaptation to diverse evaluation criteria without costly retraining. 6 | 7 | ## Project Structure 8 | 9 | ``` 10 | rewardanything/ 11 | ├── __init__.py # Package initialization 12 | ├── models.py # Data models and result classes 13 | ├── local.py # Local inference implementation 14 | ├── client.py # Remote client implementation 15 | ├── serve.py # FastAPI server implementation 16 | ├── cli.py # Command-line interface 17 | ├── utils.py # Utility functions (OpenAI client, rate limiting) 18 | └── benchmarks.py # Benchmark evaluation tools (optional) 19 | 20 | configs/ 21 | └── server_config.json # Example server configuration 22 | 23 | docs/ 24 | ├── PROJECT_DOCS.md # This file 25 | ├── API_REFERENCE.md # Detailed API documentation 26 | └── DEPLOYMENT_GUIDE.md # Production deployment guide 27 | 28 | tests/ 29 | ├── test_local.py # Local inference tests 30 | ├── test_client.py # Remote client tests 31 | └── test_server.py # Server functionality tests 32 | 33 | examples/ 34 | ├── basic_usage.py # Basic usage examples 35 | ├── batch_evaluation.py # Batch processing examples 36 | └── custom_principles.py # Advanced principle examples 37 | ``` 38 | 39 | ## Core Components 40 | 41 | ### 1. 
Local Inference (`local.py`) 42 | 43 | The local inference module provides direct model loading and evaluation: 44 | 45 | ```python 46 | import rewardanything 47 | 48 | # Load model locally 49 | reward_model = rewardanything.from_pretrained( 50 | "RewardAnything/RewardAnything-8B", 51 | device="cuda", 52 | torch_dtype="auto" 53 | ) 54 | 55 | # Evaluate responses 56 | result = reward_model.judge( 57 | principle="Prefer concise, accurate responses", 58 | prompt="What is Python?", 59 | responses={ 60 | "model_a": "Python is a programming language...", 61 | "model_b": "Python is a snake." 62 | } 63 | ) 64 | ``` 65 | 66 | **Key Features:** 67 | - Direct model loading from HuggingFace 68 | - GPU/CPU support with automatic device detection 69 | - Batch processing capabilities 70 | - Customizable generation parameters 71 | - Response masking to prevent bias 72 | 73 | ### 2. Remote Client (`client.py`) 74 | 75 | The remote client enables interaction with RewardAnything servers: 76 | 77 | ```python 78 | import rewardanything 79 | 80 | # Connect to server 81 | client = rewardanything.Client( 82 | base_url="http://localhost:8000", 83 | api_key="your-api-key", # Optional 84 | timeout=30.0 85 | ) 86 | 87 | # Same API as local inference 88 | result = client.judge( 89 | principle="Prioritize safety and helpfulness", 90 | prompt="How to learn programming?", 91 | responses=responses 92 | ) 93 | ``` 94 | 95 | **Key Features:** 96 | - HTTP-based communication 97 | - Automatic retry with exponential backoff 98 | - Authentication support 99 | - Batch processing 100 | - Health check capabilities 101 | 102 | ### 3. Server Implementation (`serve.py`) 103 | 104 | The server provides a FastAPI-based REST API for RewardAnything: 105 | 106 | ```bash 107 | # Start server 108 | rewardanything-serve -c configs/server_config.json --port 8000 109 | ``` 110 | 111 | **API Endpoints:** 112 | - `POST /api/rewardanything` - Single evaluation 113 | - `POST /api/rewardanything_batch` - Batch evaluation 114 | - `POST /api/new_batch_request` - Async batch processing 115 | - `GET /api/fetch_results/{batch_id}` - Retrieve batch results 116 | - `GET /health` - Health check 117 | 118 | ### 4. Data Models (`models.py`) 119 | 120 | Core data structures for the framework: 121 | 122 | ```python 123 | @dataclass 124 | class RewardResult: 125 | reasoning: str # Model's reasoning process 126 | scores: Dict[str, float] # Model scores (1-5) 127 | ranking: List[str] # Best to worst ranking 128 | raw_output: Optional[str] = None # Raw model output 129 | 130 | class RewardRequest(BaseModel): 131 | principle: str # Evaluation principle 132 | prompt: str # Input prompt 133 | responses: Dict[str, str] # Model responses 134 | mask_responses: bool = True # Whether to mask model names 135 | ``` 136 | 137 | ## Installation and Setup 138 | 139 | ### Basic Installation 140 | 141 | ```bash 142 | pip install RewardAnything 143 | ``` 144 | 145 | ### Development Installation 146 | 147 | ```bash 148 | git clone https://github.com/zhuohaoyu/RewardAnything.git 149 | cd RewardAnything 150 | pip install -e ".[dev]" 151 | ``` 152 | 153 | ### Server Installation 154 | 155 | ```bash 156 | pip install "RewardAnything[server]" 157 | ``` 158 | 159 | ### Full Installation 160 | 161 | ```bash 162 | pip install "RewardAnything[all]" 163 | ``` 164 | 165 | ## Usage Patterns 166 | 167 | ### 1. 
Research and Experimentation 168 | 169 | For research use cases, local inference is recommended: 170 | 171 | ```python 172 | import rewardanything 173 | 174 | # Load model with specific configuration 175 | model = rewardanything.from_pretrained( 176 | "RewardAnything/RewardAnything-8B", 177 | device="cuda", 178 | torch_dtype="bfloat16", 179 | generation_config={ 180 | "temperature": 0.1, 181 | "max_new_tokens": 2048 182 | } 183 | ) 184 | 185 | # Evaluate with complex principles 186 | principle = """ 187 | Evaluate responses based on: 188 | 1. Factual accuracy (50% weight) 189 | 2. Clarity and structure (30% weight) 190 | 3. Engagement and tone (20% weight) 191 | """ 192 | 193 | result = model.judge(principle, prompt, responses) 194 | ``` 195 | 196 | ### 2. Production Deployment 197 | 198 | For production use cases, use the server: 199 | 200 | ```bash 201 | # Start server 202 | rewardanything-serve -c production_config.json --port 8000 203 | 204 | # Scale with load balancer and multiple instances 205 | # Use Docker for containerization 206 | ``` 207 | 208 | ```python 209 | # Client usage in production 210 | client = rewardanything.Client("https://api.yourservice.com/v1") 211 | results = client.judge_batch(evaluation_requests) 212 | ``` 213 | 214 | ### 3. RLHF Integration 215 | 216 | Integration with reinforcement learning from human feedback: 217 | 218 | ```python 219 | def reward_function(prompt, response): 220 | principle = "Reward helpful, harmless, and honest responses" 221 | result = reward_model.judge( 222 | principle=principle, 223 | prompt=prompt, 224 | responses={"candidate": response} 225 | ) 226 | return result.scores["candidate"] 227 | 228 | # Use in PPO/GRPO training loops 229 | ``` 230 | 231 | ## Configuration 232 | 233 | ### Local Model Configuration 234 | 235 | ```python 236 | model = rewardanything.from_pretrained( 237 | model_name_or_path="RewardAnything/RewardAnything-8B", 238 | device="cuda", # Device placement 239 | torch_dtype="auto", # Automatic dtype selection 240 | trust_remote_code=True, # Trust remote code 241 | generation_config={ # Generation parameters 242 | "max_new_tokens": 2048, 243 | "temperature": 0.1, 244 | "do_sample": True, 245 | "top_p": 0.9 246 | } 247 | ) 248 | ``` 249 | 250 | ### Server Configuration 251 | 252 | ```json 253 | { 254 | "api_model": "gpt-4-turbo-preview", 255 | "api_key": ["key1", "key2"], 256 | "api_base": ["https://api.openai.com/v1"], 257 | "generation_config": { 258 | "max_tokens": 2048, 259 | "temperature": 0.1, 260 | "frequency_penalty": 0, 261 | "presence_penalty": 0 262 | }, 263 | "num_workers": 8, 264 | "request_limit": 100, 265 | "request_limit_period": 60 266 | } 267 | ``` 268 | 269 | ## Advanced Features 270 | 271 | ### Response Masking 272 | 273 | RewardAnything automatically masks model names during evaluation to prevent bias: 274 | 275 | ```python 276 | result = model.judge( 277 | principle="Judge based on helpfulness", 278 | prompt="How to cook pasta?", 279 | responses={ 280 | "gpt4": "Boil water, add pasta, cook for 8-10 minutes...", 281 | "claude": "Start by bringing a large pot of salted water to boil..." 
282 | }, 283 | mask_responses=True # Default: True 284 | ) 285 | # Model sees "model-1", "model-2" instead of "gpt4", "claude" 286 | ``` 287 | 288 | ### Batch Processing 289 | 290 | ```python 291 | # Local batch processing 292 | requests = [ 293 | { 294 | "principle": "Prefer technical accuracy", 295 | "prompt": "Explain machine learning", 296 | "responses": {...} 297 | }, 298 | { 299 | "principle": "Favor practical examples", 300 | "prompt": "How to debug code?", 301 | "responses": {...} 302 | } 303 | ] 304 | 305 | results = model.judge_batch(requests, batch_size=4) 306 | 307 | # Remote batch processing 308 | results = client.judge_batch(requests) 309 | ``` 310 | 311 | ### Custom Principles 312 | 313 | RewardAnything excels with sophisticated, multi-criteria principles: 314 | 315 | ```python 316 | complex_principle = """ 317 | Evaluate responses using these criteria: 318 | 319 | 1. **Technical Accuracy** (40%): 320 | - Factual correctness 321 | - Up-to-date information 322 | - Proper terminology 323 | 324 | 2. **Clarity** (30%): 325 | - Clear explanations 326 | - Logical structure 327 | - Appropriate detail level 328 | 329 | 3. **Practical Value** (20%): 330 | - Actionable advice 331 | - Real-world applicability 332 | - Concrete examples 333 | 334 | 4. **Safety** (10%): 335 | - No harmful content 336 | - Appropriate disclaimers 337 | - Ethical considerations 338 | 339 | For conflicting criteria, prioritize safety > accuracy > clarity > practical value. 340 | """ 341 | 342 | result = model.judge(complex_principle, prompt, responses) 343 | ``` 344 | 345 | ## Testing 346 | 347 | Run the test suite: 348 | 349 | ```bash 350 | # All tests 351 | pytest 352 | 353 | # Specific test modules 354 | pytest tests/test_local.py -v 355 | pytest tests/test_client.py -v 356 | pytest tests/test_server.py -v 357 | 358 | # With coverage 359 | pytest --cov=rewardanything tests/ 360 | ``` 361 | 362 | ## Contributing 363 | 364 | See [CONTRIBUTING.md](../CONTRIBUTING.md) for development guidelines. 365 | 366 | ### Development Workflow 367 | 368 | 1. Fork the repository 369 | 2. Create a feature branch 370 | 3. Make changes with tests 371 | 4. Run tests and linting 372 | 5. Submit a pull request 373 | 374 | ```bash 375 | # Development setup 376 | git clone https://github.com/your-username/RewardAnything.git 377 | cd RewardAnything 378 | pip install -e ".[dev]" 379 | 380 | # Pre-commit hooks 381 | pre-commit install 382 | 383 | # Run tests 384 | pytest 385 | 386 | # Code formatting 387 | black rewardanything/ 388 | isort rewardanything/ 389 | 390 | # Type checking 391 | mypy rewardanything/ 392 | ``` 393 | 394 | ## Troubleshooting 395 | 396 | ### Common Issues 397 | 398 | 1. **CUDA Out of Memory** 399 | ```python 400 | # Use smaller model or CPU 401 | model = rewardanything.from_pretrained( 402 | "RewardAnything/RewardAnything-1B", # Smaller model 403 | device="cpu" # Or use CPU 404 | ) 405 | ``` 406 | 407 | 2. **Server Connection Issues** 408 | ```python 409 | # Check server health 410 | client = rewardanything.Client("http://localhost:8000") 411 | if not client.health_check(): 412 | print("Server is not responding") 413 | ``` 414 | 415 | 3. **Rate Limiting** 416 | ```python 417 | # Adjust client timeout and retries 418 | client = rewardanything.Client( 419 | base_url="http://localhost:8000", 420 | timeout=120.0, # Longer timeout 421 | max_retries=5 # More retries 422 | ) 423 | ``` 424 | 425 | ### Performance Optimization 426 | 427 | 1. 
**Use appropriate hardware** 428 | - GPU with sufficient VRAM for local inference 429 | - Multiple workers for server deployment 430 | 431 | 2. **Batch processing** 432 | - Use batch methods for multiple evaluations 433 | - Adjust batch size based on available memory 434 | 435 | 3. **Caching** 436 | - Server automatically caches responses 437 | - Use consistent request IDs for cache hits 438 | 439 | ## License 440 | 441 | Apache 2.0 License - see [LICENSE](../LICENSE) for details. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
RewardAnything

Website · Model Weights · Paper · PyPI

# RewardAnything: Generalizable Principle-Following Reward Models

Zhuohao Yu¹ §, Jiali Zeng², Weizheng Gu¹, Yidong Wang¹, Jindong Wang³, Fandong Meng², Jie Zhou², Yue Zhang⁴, Shikun Zhang¹, Wei Ye¹ †

¹Peking University   ²WeChat AI   ³William & Mary   ⁴Westlake University

§Work done during Zhuohao's internship at Pattern Recognition Center, WeChat AI, Tencent Inc.   †Corresponding author.
39 | 40 | Traditional reward models learn **implicit preferences** from fixed datasets, leading to static judgments that struggle with the **nuanced and multifaceted nature of human values**. 41 | We believe that, much like Large Language Models follow diverse instructions, reward models must be able to understand and follow **explicitly specified principles**. 42 | 43 | **RewardAnything** embodies this new paradigm. Our models are designed to interpret natural language principles at inference time, enabling **dynamic adaptation** to a wide array of evaluation criteria **without costly retraining**. This approach shifts from fitting a single preference distribution to achieving true principle-following generalization. 44 | 45 | ## 🌟 Key Features 46 | 47 | - 🧠 **Principle-Following**: Directly interprets and applies reward criteria specified in natural language 48 | - 🔄 **Dynamic Adaptability**: Generalizes to new, unseen principles at inference time without retraining 49 | - 💰 **Resource Efficient**: Eliminates costly cycles of collecting preference data and retraining RMs 50 | - 📊 **State-of-the-Art Performance**: Achieves SOTA on RM-Bench and excels on our RABench benchmark 51 | - 🧩 **Easy Integration**: Works seamlessly with existing RLHF pipelines (PPO, GRPO) 52 | - 🔍 **Interpretable**: Provides transparent reasoning for evaluation decisions 53 | 54 | ## 🚀 Quick Start 55 | 56 | ### Installation 57 | 58 | ```bash 59 | pip install rewardanything 60 | ``` 61 | 62 | RewardAnything offers three flexible deployment options to fit your workflow: 63 | 64 | ## 1. 🏠 Local Inference (Recommended for Quick Testing) 65 | 66 | **Best for**: Quick experimentation, small-scale evaluation, research 67 | 68 | **Pros**: Simple setup, no external dependencies 69 | **Cons**: Requires local GPU, slower for batch processing 70 | 71 | ```python 72 | import rewardanything 73 | 74 | # Load model locally (similar to HuggingFace) 75 | reward_model = rewardanything.from_pretrained( 76 | "WisdomShell/RewardAnything-8B-v1", # Model path/name 77 | device="cuda", # Device placement 78 | torch_dtype="auto" # Automatic dtype selection 79 | ) 80 | 81 | # Define your evaluation principle 82 | principle = "I prefer clear, concise and helpful responses over long and detailed ones." 83 | 84 | # Your evaluation data 85 | prompt = "How do I learn Python programming effectively?" 86 | responses = { 87 | "response_a": "Start with Python.org's tutorial, practice daily with small projects, and join r/learnpython for help. Focus on fundamentals first.", 88 | "response_b": "Here's a comprehensive approach: 1) Start with Python basics including variables, data types, operators, control structures like if-statements, for-loops, while-loops, and functions, 2) Practice with small projects like calculators, text games, and data manipulation scripts, 3) Use interactive platforms like Codecademy, Python.org's official tutorial, edX courses, Coursera specializations, and YouTube channels, 4) Join communities like r/learnpython, Stack Overflow, Python Discord servers, and local meetups for support and networking, 5) Build progressively complex projects including web scrapers, APIs, data analysis tools, and web applications, 6) Read books like 'Automate the Boring Stuff', 'Python Crash Course', and 'Effective Python', 7) Dedicate 1-2 hours daily for consistent progress and track your learning journey.", 89 | "response_c": "Learn Python by coding." 
90 | } 91 | 92 | # Get comprehensive evaluation 93 | result = reward_model.judge( 94 | principle=principle, 95 | prompt=prompt, 96 | responses=responses 97 | ) 98 | 99 | print(f"Scores: {result.scores}") 100 | print(f"Best to worst: {result.ranking}") 101 | print(f"Reasoning: {result.reasoning}") 102 | ``` 103 | 104 | ## 2. 🚀 vLLM Deployment (Recommended for Production & RL Training) 105 | 106 | **Best for**: High-throughput batch inference, RLHF training, production workloads 107 | 108 | **Pros**: Fast batch processing, optimized inference, scalable 109 | **Cons**: Requires vLLM setup 110 | 111 | ### Step 1: Setup vLLM Server 112 | 113 | First, install and start a vLLM server. See the [vLLM quickstart guide](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#openai-compatible-server) for detailed instructions: 114 | 115 | ```bash 116 | # Install vLLM 117 | pip install vllm 118 | 119 | # Start vLLM server with RewardAnything model 120 | vllm serve WisdomShell/RewardAnything-8B-v1 \ 121 | --host 0.0.0.0 \ 122 | --port 8000 \ 123 | --max-model-len 8192 \ 124 | --tensor-parallel-size 1 125 | ``` 126 | 127 | ### Step 2: Configure RewardAnything Server 128 | 129 | Create a config file `config.json`: 130 | 131 | ```json 132 | { 133 | "api_key": ["dummy-key-for-vllm"], 134 | "api_model": "WisdomShell/RewardAnything-8B-v1", 135 | "api_base": ["http://localhost:8000/v1"], 136 | "api_timeout": 120.0, 137 | "generation_config": { 138 | "temperature": 0.0, 139 | "max_tokens": 4096 140 | }, 141 | "num_workers": 8, 142 | "request_limit": 500, 143 | "request_limit_period": 60 144 | } 145 | ``` 146 | 147 | ### Step 3: Start RewardAnything Server 148 | 149 | ```bash 150 | # Start the RewardAnything API server 151 | rewardanything serve -c config.json --port 8001 152 | ``` 153 | 154 | ### Step 4: Use in Your Code 155 | 156 | ```python 157 | import rewardanything 158 | 159 | # Connect to the RewardAnything server 160 | client = rewardanything.Client("http://localhost:8001") 161 | 162 | # Process batch requests efficiently 163 | requests = [ 164 | { 165 | "principle": "Prefer clear, concise and helpful responses over long and detailed ones.", 166 | "prompt": "How to learn programming?", 167 | "responses": { 168 | "assistant_a": "Start with Python, practice daily, build projects.", 169 | "assistant_b": "Read books and hope for the best.", 170 | "assistant_c": "Start with Python.org's tutorial, practice daily with small projects, and join r/learnpython for help. Focus on fundamentals first." 171 | } 172 | }, 173 | # ... more requests 174 | ] 175 | 176 | results = client.judge_batch(requests) 177 | for result in results: 178 | print(f"Winner: {result.ranking[0]}") 179 | ``` 180 | 181 | ## 3. 
🔧 Direct HuggingFace Integration 182 | 183 | **Best for**: Custom workflows, advanced users, integration with existing HF pipelines 184 | 185 | **Pros**: Full control, custom processing 186 | **Cons**: Manual parsing required 187 | 188 | ```python 189 | from transformers import AutoTokenizer, AutoModelForCausalLM 190 | from rewardanything.processing import prepare_chat_messages, parse_rewardanything_output 191 | import torch 192 | # Load model and tokenizer directly 193 | model = AutoModelForCausalLM.from_pretrained( 194 | "WisdomShell/RewardAnything-8B-v1", 195 | torch_dtype="auto", 196 | device_map="auto" 197 | ) 198 | tokenizer = AutoTokenizer.from_pretrained("WisdomShell/RewardAnything-8B-v1") 199 | 200 | # Prepare evaluation data 201 | principle = "Judge responses based on helpfulness and accuracy" 202 | prompt = "What is the capital of France?" 203 | responses = { 204 | "model_a": "Paris is the capital of France.", 205 | "model_b": "I think it might be Lyon or Paris." 206 | } 207 | 208 | # Prepare chat messages (handles masking automatically) 209 | messages, masked2real = prepare_chat_messages(principle, prompt, responses) 210 | 211 | # Format with chat template 212 | formatted_input = tokenizer.apply_chat_template( 213 | messages, tokenize=False, add_generation_prompt=True 214 | ) 215 | 216 | # Generate response 217 | inputs = tokenizer(formatted_input, return_tensors="pt").to(model.device) 218 | with torch.no_grad(): 219 | outputs = model.generate( 220 | **inputs, 221 | max_new_tokens=4096, 222 | temperature=0.1, 223 | do_sample=True, 224 | pad_token_id=tokenizer.eos_token_id 225 | ) 226 | 227 | # Decode output 228 | generated_tokens = outputs[0][inputs.input_ids.shape[1]:] 229 | output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) 230 | 231 | # Parse structured results (handles JSON parsing robustly) 232 | result = parse_rewardanything_output(output_text, masked2real) 233 | 234 | print(f"Raw output: {output_text}") 235 | print(f"Parsed scores: {result.scores}") 236 | print(f"Ranking: {result.ranking}") 237 | print(f"Reasoning: {result.reasoning}") 238 | ``` 239 | 240 | ## 📊 When to Use Each Method 241 | 242 | | Use Case | Method | Why | 243 | |----------|--------|-----| 244 | | Quick testing | Local Inference | Simplest setup | 245 | | Research & development | Local Inference | Full control, easy debugging | 246 | | RLHF training | vLLM Deployment | High throughput, optimized for batches | 247 | | Production evaluation | vLLM Deployment | Scalable, reliable | 248 | | Large-scale evaluation | vLLM Deployment | Best performance | 249 | | Custom integration | Direct HuggingFace | Maximum flexibility | 250 | 251 | 252 | ## 🔬 Advanced Usage 253 | 254 | ### Custom Principles 255 | 256 | RewardAnything excels with sophisticated, multi-criteria principles: 257 | 258 | ```python 259 | complex_principle = """ 260 | Evaluate responses using these criteria: 261 | 1. **Technical Accuracy** (40%): Factual correctness and up-to-date information 262 | 2. **Clarity** (30%): Clear explanations and logical structure 263 | 3. **Practical Value** (20%): Actionable advice and real-world applicability 264 | 4. **Safety** (10%): No harmful content, appropriate disclaimers 265 | 266 | For conflicting criteria, prioritize: safety > accuracy > clarity > practical value. 
267 | """ 268 | 269 | result = reward_model.judge(complex_principle, prompt, responses) 270 | ``` 271 | 272 | ### Integration with RLHF 273 | 274 | ```python 275 | # Example: Use in PPO training loop 276 | def reward_function(principle, prompt, response): 277 | result = reward_model.judge( 278 | principle=principle, 279 | prompt=prompt, 280 | responses={"generated": response, "reference": "baseline response"} 281 | ) 282 | return result.scores["generated"] 283 | 284 | # Use in your RLHF training 285 | rewards = [reward_function(principle, prompt, resp) for resp in generated_responses] 286 | ``` 287 | 288 | ### Response Masking 289 | 290 | RewardAnything automatically masks model names to prevent bias: 291 | 292 | ```python 293 | result = reward_model.judge( 294 | principle="Judge based on helpfulness", 295 | prompt="How to cook pasta?", 296 | responses={ 297 | "gpt4": "Boil water, add pasta...", 298 | "claude": "Start by bringing water to boil..." 299 | }, 300 | mask_responses=True # Default: True, model sees "model-1", "model-2" 301 | ) 302 | ``` 303 | 304 | ## 📈 Performance & Benchmarks 305 | 306 | Please refer to our paper for performance metrics and comparison. 307 | 308 | ## 📚 Documentation 309 | 310 | - [Full Documentation](docs/PROJECT_DOCS.md) 311 | - [API Reference](docs/api.md) 312 | - [Examples](examples/) 313 | - [Configuration Guide](docs/configuration.md) 314 | 315 | ## 🤝 Contributing 316 | 317 | We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines. 318 | 319 | ## 📄 Citation 320 | 321 | ```bibtex 322 | @article{yu2025rewardanything, 323 | title={RewardAnything: Generalizable Principle-Following Reward Models}, 324 | author={Yu, Zhuohao and Zeng, Jiali and Gu, Weizheng and Wang, Yidong and Wang, Jindong and Meng, Fandong and Zhou, Jie and Zhang, Yue and Zhang, Shikun and Ye, Wei}, 325 | journal={arXiv preprint arXiv:2506.03637}, 326 | year={2025} 327 | } 328 | ``` 329 | 330 | ## 📝 License 331 | 332 | This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details. 333 | 334 | ## 🙏 Acknowledgments 335 | 336 | Special thanks to the open-source community and all contributors who made this project possible. 337 | -------------------------------------------------------------------------------- /rewardanything/processing.py: -------------------------------------------------------------------------------- 1 | """Unified input preparation and output parsing for RewardAnything.""" 2 | 3 | import json 4 | import random 5 | import re 6 | from typing import Dict, List, Optional, Tuple, Any 7 | import numpy as np 8 | from scipy.stats import kendalltau, spearmanr 9 | from .models import RewardRequest, RewardResult 10 | 11 | # System prompt for RewardAnything evaluation 12 | SYSTEM_PROMPT = """ 13 | You are an evaluator judging model responses based on a given **evaluation principle**. Your primary goal is to assess how well each response for the prompt adheres to the principle, placing this above typical general preferences, though you should not endorse harmful content. 14 | 15 | Your task: 16 | 1. Read the principle, prompt, and all responses carefully and consider how each response aligns with the principle, **briefly in a concise thinking process** 17 | 2. Score each response from 1-5: 18 | - 5: Perfect adherence + excellent quality 19 | - 4: Strong adherence with minor limitations 20 | - 3: Basic adherence 21 | - 2: Partial adherence with key omissions 22 | - 1: Poor adherence or contradicts principle 23 | 3. 
Sort responses from best to worst (distinguish between same scores) 24 | 25 | Use the scoring scale accurately based on merit - don't compress scores if responses show significant quality differences. If responses vary substantially in quality, utilize the full range (1-5) to reflect these differences. 26 | 27 | Output ONLY this JSON format: 28 | { 29 | "scores": {"model-1": 2, "model-2": 4, ...}, 30 | "best-to-worst": ["model-2", "model-1", ...] 31 | } 32 | """ 33 | 34 | 35 | def make_masked_responses(responses: Dict[str, str]) -> Tuple[Dict[str, str], Dict[str, str]]: 36 | """ 37 | Create masked model names to prevent bias in evaluation. 38 | Returns masked responses and a mapping from masked to real names. 39 | """ 40 | masked_names = [f'model-{i+1}' for i in range(len(responses))] 41 | realname_responses = list(responses.items()) 42 | random.seed(42) # For reproducibility 43 | random.shuffle(realname_responses) 44 | 45 | masked2real = {masked_names[i]: realname_responses[i][0] for i in range(len(realname_responses))} 46 | responses_masked = {masked_names[i]: realname_responses[i][1] for i in range(len(realname_responses))} 47 | 48 | return responses_masked, masked2real 49 | 50 | 51 | def prepare_chat_messages( 52 | principle: str, 53 | prompt: str, 54 | responses: Dict[str, str], 55 | mask_responses: bool = True 56 | ) -> Tuple[List[Dict[str, str]], Optional[Dict[str, str]]]: 57 | """ 58 | Prepare chat messages for RewardAnything evaluation. 59 | 60 | Returns: 61 | Tuple of (messages, masked2real_mapping) 62 | """ 63 | # Apply masking if requested 64 | masked2real = None 65 | if mask_responses: 66 | responses_to_use, masked2real = make_masked_responses(responses) 67 | else: 68 | responses_to_use = responses 69 | 70 | # Create user content as JSON 71 | user_content = json.dumps({ 72 | "principle": principle, 73 | "prompt": prompt, 74 | "responses": responses_to_use, 75 | }, ensure_ascii=False) 76 | 77 | # Construct messages 78 | messages = [ 79 | {"role": "system", "content": SYSTEM_PROMPT}, 80 | {"role": "user", "content": user_content} 81 | ] 82 | 83 | return messages, masked2real 84 | 85 | 86 | def extract_json_content(text: str) -> str: 87 | """Extract JSON content after the </think> tag.""" 88 | # First remove the thinking section 89 | after_think = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL) 90 | 91 | # Clean up any markdown code block indicators 92 | clean_text = re.sub(r'```(?:json)?|```', '', after_think) 93 | 94 | # Return the cleaned text that should contain just JSON 95 | return clean_text.strip() 96 | 97 | 98 | def fix_json_format(json_str: str) -> str: 99 | """Fix common JSON formatting issues.""" 100 | # First, replace all single quotes with double quotes 101 | # This is a simple approach but works for most cases 102 | fixed = json_str.replace("'", '"') 103 | 104 | # Add quotes to unquoted property names (if any remain) 105 | fixed = re.sub(r'([{,])\s*([a-zA-Z0-9_-]+)\s*:', r'\1"\2":', fixed) 106 | 107 | # Handle trailing commas which are invalid in JSON 108 | fixed = re.sub(r',\s*}', '}', fixed) 109 | fixed = re.sub(r',\s*]', ']', fixed) 110 | 111 | # Fix common typos in keys 112 | fixed = fixed.replace('"best-to-worse"', '"best-to-worst"') 113 | 114 | return fixed 115 | 116 | 117 | def extract_json_by_brackets(text: str) -> Optional[str]: 118 | """Extract JSON by finding first { and last }.""" 119 | first_brace = text.find('{') 120 | last_brace = text.rfind('}') 121 | 122 | if first_brace != -1 and last_brace != -1 and last_brace > first_brace: 123 | return 
text[first_brace:last_brace + 1] 124 | return None 125 | 126 | 127 | def safe_parse_json(text: str) -> Tuple[Dict, bool]: 128 | """ 129 | Safely parse JSON with multiple fallbacks. 130 | Returns (parsed_dict, used_bracket_extraction) 131 | """ 132 | # First attempt: direct parsing 133 | try: 134 | return json.loads(text), False 135 | except json.JSONDecodeError: 136 | pass 137 | 138 | # Second attempt: fix common formatting issues 139 | try: 140 | fixed = fix_json_format(text) 141 | return json.loads(fixed), False 142 | except: 143 | pass 144 | 145 | # Third attempt: extract by brackets 146 | try: 147 | bracket_json = extract_json_by_brackets(text) 148 | if bracket_json: 149 | # Try to parse the bracket-extracted content 150 | result = json.loads(bracket_json) 151 | if isinstance(result, dict): 152 | return result, True 153 | except: 154 | pass 155 | 156 | # Final fallback 157 | return {}, False 158 | 159 | 160 | def calculate_weighted_pair_penalty(pred_scores: Dict[str, int], gt_scores: Dict[str, int]) -> float: 161 | """ 162 | Calculate weighted pair reversal penalty. 163 | Higher penalty for reversing pairs with larger score differences. 164 | """ 165 | common_models = list(set(pred_scores.keys()) & set(gt_scores.keys())) 166 | 167 | if len(common_models) < 2: 168 | return 1.0 # No pairs to compare 169 | 170 | total_penalty = 0.0 171 | total_weight = 0.0 172 | 173 | # Compare all pairs 174 | for i in range(len(common_models)): 175 | for j in range(i + 1, len(common_models)): 176 | model_a, model_b = common_models[i], common_models[j] 177 | 178 | try: 179 | # Get scores 180 | pred_a = float(pred_scores[model_a]) 181 | pred_b = float(pred_scores[model_b]) 182 | gt_a = float(gt_scores[model_a]) 183 | gt_b = float(gt_scores[model_b]) 184 | 185 | # Calculate ground truth difference (weight) 186 | gt_diff = abs(gt_a - gt_b) 187 | weight = 1 + gt_diff # Base weight of 1 plus difference 188 | 189 | # Check if pair is correctly ordered 190 | pred_order = pred_a - pred_b 191 | gt_order = gt_a - gt_b 192 | 193 | # If signs differ, the pair is reversed 194 | if pred_order * gt_order < 0: 195 | penalty = weight 196 | elif pred_order == 0 and gt_order != 0: 197 | # Predicted tie when there shouldn't be one 198 | penalty = weight * 0.5 199 | else: 200 | penalty = 0 201 | 202 | total_penalty += penalty 203 | total_weight += weight 204 | 205 | except (ValueError, TypeError): 206 | continue 207 | 208 | if total_weight == 0: 209 | return 0.0 210 | 211 | # Convert to reward (higher is better) 212 | return max(0, 1 - total_penalty / total_weight) 213 | 214 | 215 | def format_reward(predict: str, ground_truth: str) -> float: 216 | """ 217 | Enhanced format reward with more granular scoring: 218 | 1. Thinking tags presence and quality 219 | 2. JSON structure validity 220 | 3. Required keys completeness 221 | 4. Model coverage 222 | 5. 
Consistency between scores and best-to-worst 223 | """ 224 | score = 0.0 225 | 226 | # Check for thinking section (20%) 227 | has_think_open = '<think>' in predict 228 | has_think_close = '</think>' in predict 229 | if has_think_open and has_think_close: 230 | score += 0.15 231 | # Bonus for substantial thinking (5%) 232 | think_match = re.search(r'<think>(.*?)</think>', predict, re.DOTALL) 233 | if think_match and len(think_match.group(1).strip()) > 50: 234 | score += 0.05 235 | elif has_think_open or has_think_close: 236 | # Partial credit for attempting 237 | score += 0.05 238 | 239 | # Extract and parse JSON 240 | json_content = extract_json_content(predict) 241 | pred_json, used_bracket_extraction = safe_parse_json(json_content) 242 | gt_json, _ = safe_parse_json(ground_truth) 243 | 244 | # Valid JSON structure (20%) 245 | if pred_json: 246 | # Apply penalty if bracket extraction was used 247 | json_score = 0.20 if not used_bracket_extraction else 0.15 248 | score += json_score 249 | 250 | # Required keys (20%) 251 | has_scores = "scores" in pred_json 252 | # Accept both "best-to-worst" and "best-to-worse" 253 | has_ranking = "best-to-worst" in pred_json or "best-to-worse" in pred_json 254 | 255 | if has_scores: 256 | score += 0.10 257 | if has_ranking: 258 | score += 0.10 259 | 260 | # Model coverage (20%) 261 | if has_scores and "scores" in gt_json: 262 | gt_models = set(gt_json["scores"].keys()) 263 | pred_models = set(pred_json["scores"].keys()) if isinstance(pred_json["scores"], dict) else set() 264 | 265 | if gt_models and pred_models: 266 | coverage = len(gt_models & pred_models) / max(1, len(gt_models)) 267 | score += 0.20 * coverage 268 | 269 | # Consistency check (20%) 270 | if has_scores and has_ranking: 271 | consistency_score = check_ranking_consistency(pred_json) 272 | score += 0.20 * consistency_score 273 | 274 | return min(1.0, score) 275 | 276 | 277 | def parse_rewardanything_output( 278 | output_text: str, 279 | masked2real: Optional[Dict[str, str]] = None 280 | ) -> RewardResult: 281 | """ 282 | Parse RewardAnything model output into structured result. 
283 | 284 | Args: 285 | output_text: Raw output from the model 286 | masked2real: Optional mapping from masked names to real names 287 | 288 | Returns: 289 | RewardResult with scores, ranking, and reasoning 290 | """ 291 | # Extract thinking section 292 | think_start = output_text.find("<think>") 293 | think_end = output_text.find("</think>") 294 | 295 | if think_start != -1 and think_end != -1: 296 | reasoning = output_text[think_start + 7:think_end].strip() 297 | else: 298 | reasoning = "" 299 | 300 | # Extract and parse JSON 301 | json_content = extract_json_content(output_text) 302 | results, _ = safe_parse_json(json_content) 303 | 304 | # Ensure the results have the expected structure 305 | if not results: 306 | results = {"scores": {}, "best-to-worst": []} 307 | if "scores" not in results: 308 | results["scores"] = {} 309 | if "best-to-worst" not in results: 310 | # Handle both key variants 311 | if "best-to-worse" in results: 312 | results["best-to-worst"] = results["best-to-worse"] 313 | else: 314 | results["best-to-worst"] = [] 315 | 316 | scores = results.get("scores", {}) 317 | ranking = results.get("best-to-worst", []) 318 | 319 | # Convert masked names back to real names if needed 320 | if masked2real: 321 | # Convert scores 322 | real_scores = {} 323 | for masked_name, score in scores.items(): 324 | if masked_name in masked2real: 325 | real_scores[masked2real[masked_name]] = score 326 | else: 327 | real_scores[masked_name] = score # Keep unmapped names as-is 328 | 329 | # Convert ranking 330 | real_ranking = [] 331 | for masked_name in ranking: 332 | if masked_name in masked2real: 333 | real_ranking.append(masked2real[masked_name]) 334 | else: 335 | real_ranking.append(masked_name) # Keep unmapped names as-is 336 | 337 | scores = real_scores 338 | ranking = real_ranking 339 | 340 | return RewardResult( 341 | reasoning=reasoning, 342 | scores=scores, 343 | ranking=ranking, 344 | raw_output=output_text 345 | ) 346 | 347 | 348 | def analyze_prediction_quality(predict: str, ground_truth: str) -> Dict: 349 | """ 350 | Detailed analysis of a single prediction for debugging. 
351 | """ 352 | try: 353 | pred_json, used_brackets = safe_parse_json(extract_json_content(predict)) 354 | gt_json, _ = safe_parse_json(ground_truth) 355 | 356 | analysis = { 357 | "has_thinking": '<think>' in predict and '</think>' in predict, 358 | "valid_json": bool(pred_json), 359 | "used_bracket_extraction": used_brackets, 360 | "has_required_keys": bool(pred_json.get("scores")) and 361 | (bool(pred_json.get("best-to-worst")) or bool(pred_json.get("best-to-worse"))), 362 | "model_coverage": 0.0, 363 | "score_differences": {}, 364 | "reversed_pairs": [] 365 | } 366 | 367 | if pred_json and gt_json: 368 | pred_scores = pred_json.get("scores", {}) 369 | gt_scores = gt_json.get("scores", {}) 370 | 371 | common = set(pred_scores.keys()) & set(gt_scores.keys()) 372 | if gt_scores: 373 | analysis["model_coverage"] = len(common) / len(gt_scores) 374 | 375 | # Analyze score differences 376 | for model in common: 377 | try: 378 | analysis["score_differences"][model] = { 379 | "predicted": float(pred_scores[model]), 380 | "ground_truth": float(gt_scores[model]), 381 | "difference": abs(float(pred_scores[model]) - float(gt_scores[model])) 382 | } 383 | except: 384 | pass 385 | 386 | # Find reversed pairs 387 | models = list(common) 388 | for i in range(len(models)): 389 | for j in range(i + 1, len(models)): 390 | try: 391 | pred_diff = float(pred_scores[models[i]]) - float(pred_scores[models[j]]) 392 | gt_diff = float(gt_scores[models[i]]) - float(gt_scores[models[j]]) 393 | 394 | if pred_diff * gt_diff < 0: 395 | analysis["reversed_pairs"].append({ 396 | "pair": (models[i], models[j]), 397 | "weight": abs(gt_diff) 398 | }) 399 | except: 400 | pass 401 | 402 | return analysis 403 | 404 | except: 405 | return {"error": "Analysis failed"} -------------------------------------------------------------------------------- /rewardanything/serve.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import asyncio 4 | import json 5 | import logging 6 | import os 7 | import shutil 8 | import time 9 | import uuid 10 | from datetime import datetime 11 | from typing import Dict, List, Optional, Any, Union 12 | 13 | import uvicorn 14 | from fastapi import FastAPI, HTTPException, BackgroundTasks 15 | from pydantic import BaseModel 16 | 17 | # Updated imports to use unified processing 18 | from .utils import OpenAIClient, AsyncRateLimitThreadPool 19 | from .models import RewardRequest, RewardResponse 20 | from .processing import prepare_chat_messages, parse_rewardanything_output 21 | 22 | # Configure logging 23 | logging.basicConfig(level=logging.INFO, 24 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 25 | logger = logging.getLogger(__name__) 26 | 27 | app = FastAPI(title="RewardAnything API") 28 | 29 | # Store for batch tasks 30 | batch_tasks = {} 31 | 32 | 33 | class BatchRequest(BaseModel): 34 | requests: List[Dict] 35 | output_path: Optional[str] = None 36 | api_model: Optional[str] = None 37 | api_key: Optional[Union[str, List[str]]] = None 38 | api_base: Optional[Union[str, List[str]]] = None 39 | api_proxy: Optional[Union[str, List[str]]] = None 40 | api_timeout: Optional[float] = None 41 | api_max_retries: Optional[int] = None 42 | generation_config: Optional[Dict] = None 43 | num_workers: Optional[int] = None 44 | request_limit: Optional[int] = None 45 | request_limit_period: Optional[int] = None 46 | max_error_count: Optional[int] = None 47 | trial_run: Optional[bool] = None 48 | dump_individual_rsp: Optional[bool] = 
None 49 | 50 | 51 | class BatchRequestResponse(BaseModel): 52 | batch_request_id: str 53 | requests: int 54 | start_time: str 55 | 56 | 57 | class BatchTask: 58 | def __init__(self, request: BatchRequest, config: Dict, base_output_path: str): 59 | self.id = str(uuid.uuid4()) 60 | self.request = request 61 | self.config = config 62 | self.base_output_path = base_output_path 63 | self.start_time = datetime.now().isoformat() 64 | self.complete = False 65 | self.results = None 66 | self.error = None 67 | self.task = None 68 | 69 | def get_inference_params(self): 70 | # Start with default config 71 | params = self.config.copy() 72 | 73 | # Override with request params if provided 74 | for key, value in self.request.dict(exclude_unset=True).items(): 75 | if value is not None: 76 | params[key] = value 77 | 78 | # Ensure output path is set using base_output_path and batch_id 79 | if not params.get("output_path") or params["output_path"] is None: 80 | params["output_path"] = os.path.join(self.base_output_path, self.id) 81 | 82 | # Make sure the output path exists 83 | os.makedirs(params["output_path"], exist_ok=True) 84 | 85 | return params 86 | 87 | 88 | async def run_inference_without_asyncio_run( 89 | requests, 90 | output_path, 91 | api_model, 92 | api_key, 93 | api_base=None, 94 | api_proxy=None, 95 | api_timeout=30.0, 96 | api_max_retries=5, 97 | generation_config=None, 98 | num_workers=8, 99 | request_limit=100, 100 | request_limit_period=60, 101 | max_error_count=100, 102 | trial_run=False, 103 | dump_individual_rsp=True, 104 | ): 105 | """Modified version of run_api_inference that works within an existing event loop""" 106 | 107 | logger.info(f"num_requests: {len(requests)}, output_path: {output_path}") 108 | 109 | os.makedirs(output_path, exist_ok=True) 110 | 111 | if dump_individual_rsp: 112 | os.makedirs(os.path.join(output_path, "responses"), exist_ok=True) 113 | 114 | if os.path.exists(os.path.join(output_path, "all_responses.jsonl")): 115 | os.remove(os.path.join(output_path, "all_responses.jsonl")) 116 | 117 | client = OpenAIClient( 118 | output_path=output_path, 119 | api_model=api_model, 120 | api_key=api_key, 121 | api_base=api_base, 122 | api_proxy=api_proxy, 123 | api_timeout=api_timeout, 124 | api_max_retries=api_max_retries, 125 | generation_config=generation_config, 126 | trial_run=trial_run, 127 | dump_individual_rsp=dump_individual_rsp, 128 | max_error_count=max_error_count, 129 | ) 130 | 131 | # Create a pool and run queries 132 | pool = AsyncRateLimitThreadPool(num_workers, request_limit, request_limit_period) 133 | writer_task = asyncio.create_task(client.write_responses_to_file()) 134 | 135 | try: 136 | coroutines = [pool._rate_limited_call(client.query, req) for req in requests] 137 | results = await asyncio.gather(*coroutines) 138 | 139 | # Wait for all responses to be written 140 | await client.response_queue.join() 141 | finally: 142 | # Always cancel the writer task 143 | writer_task.cancel() 144 | try: 145 | await writer_task 146 | except asyncio.CancelledError: 147 | pass 148 | 149 | return results 150 | 151 | 152 | async def run_batch_task(task: BatchTask): 153 | """Execute a batch task asynchronously.""" 154 | try: 155 | # Get inference parameters 156 | params = task.get_inference_params() 157 | 158 | # Convert batch request to OpenAI request format 159 | api_requests = [] 160 | for i, req in enumerate(task.request.requests): 161 | # Prepare messages using unified processing 162 | messages, _ = prepare_chat_messages( 163 | req["principle"], 164 | 
req["prompt"], 165 | req["responses"], 166 | req.get("mask_responses", True) 167 | ) 168 | 169 | api_requests.append({ 170 | "uuid": f"batch_{task.id}_{i}", 171 | "messages": messages 172 | }) 173 | 174 | # Run inference 175 | await run_inference_without_asyncio_run( 176 | requests=api_requests, 177 | **{k: v for k, v in params.items() if k != "requests"} 178 | ) 179 | 180 | # Load results 181 | results_path = os.path.join(params["output_path"], "all_responses.jsonl") 182 | results = [] 183 | 184 | if os.path.exists(results_path): 185 | with open(results_path, 'r', encoding='utf-8') as f: 186 | for line in f: 187 | result_data = json.loads(line.strip()) 188 | content = result_data["response"]["generated_text"] 189 | 190 | # Parse using unified processing 191 | reward_result = parse_rewardanything_output(content, None) # No masking for batch 192 | 193 | results.append({ 194 | "thoughts": reward_result.reasoning, 195 | "results": { 196 | "scores": reward_result.scores, 197 | "best-to-worst": reward_result.ranking 198 | } 199 | }) 200 | 201 | task.results = results 202 | task.complete = True 203 | 204 | except Exception as e: 205 | logger.exception(f"Batch task {task.id} failed") 206 | task.error = str(e) 207 | task.complete = True 208 | 209 | 210 | @app.post("/api/rewardanything", response_model=RewardResponse) 211 | async def rewardanything_single(request: RewardRequest): 212 | """Single RewardAnything evaluation.""" 213 | # Get config from app state 214 | config = app.state.config 215 | 216 | # Prepare messages using unified processing 217 | messages, masked2real = prepare_chat_messages( 218 | request.principle, 219 | request.prompt, 220 | request.responses, 221 | request.mask_responses 222 | ) 223 | 224 | # Create a unique ID for this request 225 | request_id = str(uuid.uuid4()) 226 | 227 | # Prepare the request for OpenAI inference 228 | api_request = { 229 | "uuid": request_id, 230 | "messages": messages 231 | } 232 | 233 | # Create temporary output directory 234 | output_path = os.path.join(app.state.base_output_path, f"single_{request_id}") 235 | 236 | try: 237 | # Run inference 238 | await run_inference_without_asyncio_run( 239 | requests=[api_request], 240 | output_path=output_path, 241 | api_model=config.get("api_model"), 242 | api_key=config.get("api_key"), 243 | api_base=config.get("api_base"), 244 | api_proxy=config.get("api_proxy"), 245 | api_timeout=config.get("api_timeout", 120.0), 246 | api_max_retries=config.get("api_max_retries", 5), 247 | generation_config=config.get("generation_config"), 248 | num_workers=1, 249 | request_limit=config.get("request_limit", 100), 250 | request_limit_period=config.get("request_limit_period", 60), 251 | max_error_count=config.get("max_error_count", 50), 252 | dump_individual_rsp=False 253 | ) 254 | 255 | # Load and parse result 256 | results_path = os.path.join(output_path, "all_responses.jsonl") 257 | if not os.path.exists(results_path): 258 | raise HTTPException(status_code=500, detail="Failed to generate response") 259 | 260 | with open(results_path, 'r', encoding='utf-8') as f: 261 | result_data = json.loads(f.readline().strip()) 262 | 263 | # Parse the response content using unified processing 264 | content = result_data["response"]["generated_text"] 265 | reward_result = parse_rewardanything_output(content, masked2real) 266 | 267 | # Convert to API response format 268 | response = RewardResponse( 269 | thoughts=reward_result.reasoning, 270 | results={ 271 | "scores": reward_result.scores, 272 | "best-to-worst": reward_result.ranking 
273 | } 274 | ) 275 | 276 | # Cleanup temporary directory 277 | try: 278 | shutil.rmtree(output_path) 279 | except Exception as cleanup_error: 280 | logger.warning(f"Failed to cleanup directory {output_path}: {cleanup_error}") 281 | 282 | return response 283 | 284 | except Exception as e: 285 | # Cleanup on error as well 286 | try: 287 | if os.path.exists(output_path): 288 | shutil.rmtree(output_path) 289 | except Exception as cleanup_error: 290 | logger.warning(f"Failed to cleanup directory {output_path}: {cleanup_error}") 291 | 292 | logger.exception("Error processing reward request") 293 | raise HTTPException(status_code=500, detail=f"Error: {str(e)}") 294 | 295 | 296 | @app.post("/api/rewardanything_batch") 297 | async def rewardanything_batch(requests: List[RewardRequest], max_timeout: int = 600): 298 | """Batch RewardAnything evaluation.""" 299 | if not requests: 300 | raise HTTPException(status_code=400, detail="No requests provided") 301 | 302 | # Get config from app state 303 | config = app.state.config 304 | 305 | # Create a unique batch ID 306 | batch_id = str(uuid.uuid4()) 307 | batch_output_path = os.path.join(app.state.base_output_path, f"batch_{batch_id}") 308 | 309 | try: 310 | # Prepare all requests using unified processing 311 | api_requests = [] 312 | masked_mappings = {} # Store masked->real mappings for each request 313 | 314 | for i, req in enumerate(requests): 315 | request_id = f"{batch_id}_{i}" 316 | 317 | # Prepare messages 318 | messages, masked2real = prepare_chat_messages( 319 | req.principle, 320 | req.prompt, 321 | req.responses, 322 | req.mask_responses 323 | ) 324 | 325 | # Store the mapping for later use 326 | if masked2real: 327 | masked_mappings[request_id] = masked2real 328 | 329 | api_requests.append({ 330 | "uuid": request_id, 331 | "messages": messages 332 | }) 333 | 334 | # Run batch inference 335 | dump_individual_rsp = config.get("dump_individual_rsp", True) 336 | await run_inference_without_asyncio_run( 337 | requests=api_requests, 338 | output_path=batch_output_path, 339 | api_model=config.get("api_model"), 340 | api_key=config.get("api_key"), 341 | api_base=config.get("api_base"), 342 | api_proxy=config.get("api_proxy"), 343 | api_timeout=config.get("api_timeout", 120.0), 344 | api_max_retries=config.get("api_max_retries", 5), 345 | generation_config=config.get("generation_config"), 346 | num_workers=config.get("num_workers", 8), 347 | request_limit=config.get("request_limit", 100), 348 | request_limit_period=config.get("request_limit_period", 60), 349 | max_error_count=config.get("max_error_count", 50), 350 | dump_individual_rsp=dump_individual_rsp 351 | ) 352 | 353 | # Load and process the results 354 | results_path = os.path.join(batch_output_path, "all_responses.jsonl") 355 | if not os.path.exists(results_path): 356 | raise HTTPException(status_code=500, detail="Failed to generate responses") 357 | 358 | # Parse each response using unified processing 359 | results = [] 360 | with open(results_path, 'r', encoding='utf-8') as f: 361 | for line in f: 362 | result_data = json.loads(line.strip()) 363 | request_id = result_data["request"]["uuid"] 364 | content = result_data["response"]["generated_text"] 365 | 366 | # Get the masked->real mapping for this request 367 | masked2real = masked_mappings.get(request_id) 368 | 369 | # Parse the response using unified processing 370 | reward_result = parse_rewardanything_output(content, masked2real) 371 | 372 | # Convert to API response format 373 | response = RewardResponse( 374 | 
thoughts=reward_result.reasoning, 375 | results={ 376 | "scores": reward_result.scores, 377 | "best-to-worst": reward_result.ranking 378 | } 379 | ) 380 | results.append(response) 381 | 382 | # Cleanup temporary directory if dump_individual_rsp is enabled 383 | if dump_individual_rsp: 384 | try: 385 | shutil.rmtree(batch_output_path) 386 | except Exception as cleanup_error: 387 | logger.warning(f"Failed to cleanup directory {batch_output_path}: {cleanup_error}") 388 | 389 | return results 390 | 391 | except Exception as e: 392 | # Cleanup on error as well if dump_individual_rsp is enabled 393 | dump_individual_rsp = config.get("dump_individual_rsp", True) 394 | if dump_individual_rsp: 395 | try: 396 | if os.path.exists(batch_output_path): 397 | shutil.rmtree(batch_output_path) 398 | except Exception as cleanup_error: 399 | logger.warning(f"Failed to cleanup directory {batch_output_path}: {cleanup_error}") 400 | 401 | logger.exception(f"Error processing batch request: {e}") 402 | raise HTTPException(status_code=500, detail=f"Error: {str(e)}") 403 | 404 | 405 | @app.post("/api/batch_request", response_model=BatchRequestResponse) 406 | async def create_batch_request(request: BatchRequest, background_tasks: BackgroundTasks): 407 | """Create a new batch request for asynchronous processing.""" 408 | task = BatchTask(request, app.state.config, app.state.base_output_path) 409 | batch_tasks[task.id] = task 410 | 411 | # Start the task in the background 412 | background_tasks.add_task(run_batch_task, task) 413 | 414 | return BatchRequestResponse( 415 | batch_request_id=task.id, 416 | requests=len(request.requests), 417 | start_time=task.start_time 418 | ) 419 | 420 | 421 | @app.get("/api/batch_request/{batch_request_id}") 422 | async def get_batch_request(batch_request_id: str, max_timeout: int = 600): 423 | """Get the results of a batch request.""" 424 | if batch_request_id not in batch_tasks: 425 | raise HTTPException(status_code=404, detail="Batch request not found") 426 | 427 | task = batch_tasks[batch_request_id] 428 | 429 | start_time = time.time() 430 | while not task.complete and (time.time() - start_time) < max_timeout: 431 | await asyncio.sleep(1) 432 | 433 | if not task.complete: 434 | raise HTTPException(status_code=408, detail="Request timed out") 435 | 436 | if task.error: 437 | raise HTTPException(status_code=500, detail=f"Task failed: {task.error}") 438 | 439 | return task.results 440 | 441 | 442 | @app.post("/api/clear_tasks") 443 | async def clear_tasks(): 444 | batch_tasks.clear() 445 | return {"message": "All tasks cleared"} 446 | 447 | 448 | @app.get("/api/list_tasks") 449 | async def list_tasks(): 450 | return { 451 | task_id: { 452 | "start_time": task.start_time, 453 | "complete": task.complete, 454 | "has_error": task.error is not None, 455 | "num_requests": len(task.request.requests), 456 | "output_path": os.path.join(task.base_output_path, task.id) 457 | } 458 | for task_id, task in batch_tasks.items() 459 | } 460 | 461 | 462 | @app.get("/health") 463 | async def health_check(): 464 | """Health check endpoint""" 465 | return {"status": "healthy", "service": "RewardAnything API"} 466 | 467 | 468 | def load_config(config_path): 469 | with open(config_path, 'r') as f: 470 | return json.load(f) 471 | 472 | 473 | def create_app(config: Dict, base_output_path: str) -> FastAPI: 474 | """Create FastAPI app with configuration""" 475 | # Inject config into the app 476 | app.state.config = config 477 | app.state.base_output_path = base_output_path 478 | return app 479 | 480 | 481 | 
def main(): 482 | parser = argparse.ArgumentParser(description="RewardAnything API Server") 483 | parser.add_argument("-c", "--config", required=True, help="Path to configuration file") 484 | parser.add_argument("--port", type=int, default=8000, help="Port to listen on") 485 | parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") 486 | parser.add_argument("--base-output-path", default="./outputs", 487 | help="Base directory for storing batch outputs") 488 | 489 | args = parser.parse_args() 490 | 491 | # Load configuration 492 | config = load_config(args.config) 493 | 494 | # Get absolute path for base output path 495 | base_output_path = os.path.abspath(args.base_output_path) 496 | 497 | # Create the output directory if it doesn't exist 498 | os.makedirs(base_output_path, exist_ok=True) 499 | 500 | # Create app with config 501 | app = create_app(config, base_output_path) 502 | 503 | logger.info(f"Server starting with base output path: {base_output_path}") 504 | 505 | # Run the server 506 | uvicorn.run(app, host=args.host, port=args.port) 507 | 508 | 509 | if __name__ == "__main__": 510 | main() -------------------------------------------------------------------------------- /pages/index.html: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | title: "Home" 4 | description: "RewardAnything: Generalizable Principle-Following Reward Models" 5 | --- 6 | 7 | 8 |
9 |
10 |
11 | 12 | RewardAnything 15 |
16 | 17 |

18 | Generalizable Principle-Following Reward Models 19 |

20 | 21 | 22 |
23 |

24 | Traditional reward models learn implicit preferences from fixed data, 25 | but human values are too nuanced for any single, static model. 26 |

27 |

28 | We believe reward models, much like LLMs with instructions, must follow 29 | explicitly specified principles. 30 | This unlocks inference-time adaptability to diverse criteria—without costly retraining. 31 |

32 |
33 | 34 | 35 | 49 |
50 |
51 | 52 | 53 |
54 |
55 |
56 |
57 | 58 |
59 | Zhuohao Yu1,§ 60 | Jiali Zeng2 61 | Weizheng Gu1 62 | Yidong Wang1 63 | Jindong Wang3 64 |
65 | 66 |
67 | Fandong Meng2 68 | Jie Zhou2 69 | Yue Zhang4 70 | Shikun Zhang1 71 | Wei Ye1,† 72 |
73 |
74 |
75 |
76 |
77 | 1Peking University   78 | 2WeChat AI   79 | 3William & Mary   80 | 4Westlake University 81 |
82 |

83 | §Work done during internship at WeChat AI   Corresponding author 84 |

85 |
86 |
87 |
88 | 89 | 90 |
91 |
92 |
93 |

The Core Problem: Flawed Training & Narrow Evaluation

94 |

95 | Current reward models face fundamental limitations in how they are trained and evaluated, hindering their ability to truly align with diverse human values. 96 |

97 |
98 | 99 | 100 |
101 |

1. Problematic Training: Learning Static & Biased Preferences

102 |

103 | Reward models are typically trained on vast datasets of (prompt, chosen response, rejected response) tuples. This teaches the model a single, implicit preference distribution. 104 |

105 |
106 |
107 |

No Principled Control:

108 |

Even if the prompt and responses are identical, applying different evaluation principles (e.g., "be concise" vs. "be detailed") should lead to different rankings. Current RMs struggle to adapt this way without costly retraining for each new principle.

109 |
110 |
111 |

Implicit & Outcome-Only Learning:

112 |

Models learn what to prefer based on outcomes, but not why. This lack of explicit rationale leads to learning superficial patterns or spurious correlations (e.g., "longer is better," "lists are good") rather than the true underlying human intent, as shown below.

113 |
114 |
115 |
116 | 117 | 118 |
119 | 120 |
121 |
122 |
123 | 124 | 125 | 126 |
127 |
128 |

Issue #1: Length = Quality Bias

129 |

Models learn "longer responses are better" from pairs where correctness correlates with length.

130 |
131 |
132 | 133 |
134 |
135 |
🙋 Prompt:
136 |

"What are some species of bears that are now extinct?"

137 |
138 | 139 |
140 |
141 |
142 | ✅ Chosen (Long + Correct) 143 |
144 |

145 | "Several species of bears have become extinct... Cave Bear (Ursus spelaeus): One of the best-known extinct bear species... Short-faced Bear (Arctodus simus): Once the largest..." 146 |

147 |
148 | ✓ Accurate facts ✓ Detailed explanations 149 |
150 |
151 | 152 |
153 |
154 | ❌ Rejected (Short + Wrong) 155 |
156 |

157 | "Three species of bears that are now extinct are the woolly mammoth, the woolly rhinoceros, and the thylacine." 158 |

159 |
160 | ❌ Factually incorrect ❌ None are bears 161 |
162 |
163 |
164 | 165 |
166 |
⚠️ What Current RMs Learn:
167 |

A spurious correlation: "Longer responses are better." This preference is static, but what if the user actually preferred a brief, accurate answer?

168 |
169 |
170 |
171 | 172 | 173 |
174 |
175 |
176 | 177 | 178 | 179 |
180 |
181 |

Issue #2: Format Over Substance

182 |

Models often prioritize familiar structures (e.g., lists) over equally valid, natural content.

183 |
184 |
185 | 186 |
187 |
188 |
🙋 Prompt:
189 |

"What are some good browser alternatives to Chrome?"

190 |
191 | 192 |
193 |
194 |
195 | ✅ Chosen (Well-Structured) 196 |
197 |

198 | "There are several good browser alternatives to Chrome: 199 |
1. Mozilla Firefox: Known for strong privacy features, being open-source, and highly customizable. 200 |
2. Microsoft Edge: Now built on Chromium, offering good performance and compatibility." 201 |

202 |
203 | ✓ Clear, itemized structure ✓ Detailed points 204 |
205 |
206 | 207 |
208 |
209 | ⚠️ Rejected (Natural but Correct) 210 |
211 |

212 | "Sure! For browser alternatives, you could check out Firefox – it's really good for privacy and you can customize it a lot. Microsoft Edge is another option; it's pretty fast now that it uses Chromium tech." 213 |

214 |
215 | ✓ Factually correct ✓ Conversational style ✓ Complete 216 |
217 |
218 |
219 | 220 |
221 |
⚠️ What Current RMs Learn:
222 |

"Structured, list-like responses are better." This overlooks that a natural, conversational style might be equally informative or even preferred by some users.

223 |
224 |
225 |
226 |
227 | 228 | 229 |
230 |

2. Incomplete Evaluation: Missing True Generalization

231 |

232 | Existing Reward Model benchmarks primarily measure how well an RM aligns with a single, predefined preference distribution (often the one it was trained on or a similar one). 233 |

234 |
235 |
236 |

Ignoring Multifaceted Values:

237 |

Human preferences are complex, context-dependent, and multifaceted. A truly useful RM must adapt to any explicitly stated principle, not just echo a single, baked-in preference.

238 |
239 |
240 |

Superficial Alignment:

241 |

This narrow evaluation fails to assess the critical capability of generalizing to diverse and novel principles at inference time, which is essential for robust and trustworthy AI systems.

242 |
243 |
244 |
Consequences: Why Current Reward Models Fall Short
These fundamental issues in training and evaluation lead to several critical shortcomings:

• Overfitting to Static Preferences: RMs master a single, fixed preference from their training data, failing to grasp the multifaceted nature of human values or adapt to diverse contexts.
• Opaque & Implicit Reasoning: learning from outcomes alone (chosen/rejected pairs), RMs lack an explicit understanding of *why* a response is preferred, making their judgments uninterpretable black boxes.
• Vulnerability to Spurious Correlations: implicit learning on biased data leads RMs to mistake superficial cues (e.g., length, format, specific keywords) for genuine quality.
• Costly & Inefficient Adaptation: because their preferences are static and their reasoning opaque, aligning RMs with new criteria or principles demands expensive data collection and full retraining cycles.
The Solution: Principle-Following Reward Models

To overcome these limitations, we propose a paradigm shift towards reward models that explicitly understand and follow natural language principles. This approach enables dynamic adaptation to any evaluation criteria without costly retraining, and it is embodied by two key innovations:

🎯 1. A New Evaluation Paradigm: RABench
Current benchmarks assess how well RMs fit a single, fixed preference. This is insufficient. We argue that, just as Large Language Models (LLMs) are valued for their ability to follow diverse instructions, reward models must be evaluated on their capacity to follow diverse principles.

To this end, we introduce RABench (RewardAnything Benchmark), a comprehensive benchmark designed to assess the principle-following capabilities of RMs across multiple domains (chat, code, safety, math) and a wide array of explicit natural language criteria.

RABench moves beyond static preference matching, pushing for RMs that demonstrate true generalization in understanding and applying "goodness" based on varied, explicit guidance.
🏆 2. The RewardAnything Model
We develop RewardAnything, a novel reward model engineered to embody this principle-following paradigm.

Trained with reinforcement learning (RL) on principle-conditioned preference data, RewardAnything learns to robustly distinguish better responses from worse ones by directly conditioning on explicit natural language principles provided at inference time. This allows it to adapt its judgment dynamically, without any retraining.

A key feature is its inference-time reasoning process: RewardAnything not only scores responses according to the given principle but can also articulate an explanation for its judgment, enhancing transparency and trustworthiness.
Figure 1: Current post-training optimization paradigm vs. the RewardAnything approach.
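To make this concrete, here is a minimal sketch of principle-conditioned judging, using the same judge API shown in the Quick Start below; the prompt, responses, and principles are illustrative only:

import rewardanything

# Minimal sketch: the same two responses judged under two different principles.
# Everything below except the model name and the judge API is illustrative.
rm = rewardanything.from_pretrained("zhuohaoyu/RewardAnything-8B-v1", device="cuda")

prompt = "Explain what a hash table is."
responses = {
    "terse": "A hash table maps keys to values via a hash function, giving fast average-case lookup.",
    "detailed": ("A hash table is a data structure that maps keys to values. A hash function turns each key "
                 "into an index into an array of buckets, so lookups, insertions, and deletions are O(1) on "
                 "average; collisions are resolved with chaining or open addressing."),
}

# Only the stated principle changes: same model, same responses, no retraining.
for principle in [
    "Prefer concise answers that state only the essential idea.",
    "Prefer thorough answers that explain the underlying mechanism.",
]:
    result = rm.judge(principle=principle, prompt=prompt, responses=responses)
    print(principle, "->", result.ranking)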
📖 Dive Deeper into the Details
For a comprehensive understanding of our methodology, technical innovations, detailed model architecture, training procedures, and full experimental setup, please refer to our full research paper. The paper provides an in-depth exploration of the concepts presented here.

📄 Read the Full Paper
🚀 Quick Start
RewardAnything offers three flexible deployment options to fit your workflow, from quick experimentation to production-scale evaluation.

📦 Installation
pip install rewardanything
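To verify the installation, a minimal import check is enough (a small sketch; it assumes the package exposes a version string at the top level):

# Quick smoke test after `pip install rewardanything`
import rewardanything
print(rewardanything.__version__)  # assumed to be defined at the package root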
🏠 Local Inference (Recommended for Beginners)
Perfect for quick experimentation and research.
Best for: Quick Testing, Research, Offline Use

✅ Pros:
• Simple one-line setup
• No external dependencies
• Full control & offline capable

⚠️ Cons:
• Local GPU required (8GB+ VRAM)
• Not ideal for batch processing

View Guides →
🚀 vLLM Deployment (Recommended for Production)
Optimized for high throughput and production use.
Best for: Production, RLHF Training, Batch Processing

✅ Pros:
• Distributed & concurrent inference
• Production-ready scalability
• Optimized memory usage

⚠️ Cons:
• vLLM setup required
• More complex configuration

View Guides →
🔧 Transformers Direct (Recommended for Customization)
Maximum flexibility for custom workflows.
Best for: Custom Logic, Research, Low-level Control

✅ Pros:
• Full model control & access
• Custom processing pipelines
• HuggingFace ecosystem

⚠️ Cons:
• Manual output parsing
• More boilerplate code

View Guides →
🏠 Local Inference
Simple setup for quick testing and research.

import rewardanything

# Load the model locally (similar to HuggingFace)
reward_model = rewardanything.from_pretrained("zhuohaoyu/RewardAnything-8B-v1", device="cuda")

# Get a comprehensive evaluation
result = reward_model.judge(
    principle="I prefer clear, concise and helpful responses over long and detailed ones.",
    prompt="How do I learn Python programming effectively?",
    responses={  # keyed responses; keys are masked and shuffled before being passed to RewardAnything to prevent cheating
        "response_a": "Start with Python.org's tutorial, practice daily with small projects, and join r/learnpython for help. Focus on fundamentals first.",
        "response_b": "Here's a comprehensive approach: 1) Start with Python basics including variables, data types, operators, control structures like if-statements, for-loops, while-loops, and functions, 2) Practice with small projects like calculators, text games, and data manipulation scripts, 3) Use interactive platforms like Codecademy, Python.org's official tutorial, edX courses, Coursera specializations, and YouTube channels, 4) Join communities like r/learnpython, Stack Overflow, Python Discord servers, and local meetups for support and networking, 5) Build progressively complex projects including web scrapers, APIs, data analysis tools, and web applications, 6) Read books like 'Automate the Boring Stuff', 'Python Crash Course', and 'Effective Python', 7) Dedicate 1-2 hours daily for consistent progress and track your learning journey.",
        "response_c": "Learn Python by coding."
    }
)

# Access results
print(f"Scores: {result.scores}")
print(f"Ranking: {result.ranking}")
print(f"Reasoning: {result.reasoning}")
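The vLLM route compared above is documented in the guides linked from its card; the outline below is only a rough sketch under stated assumptions (the serve command, port, and prompt formatting are not taken from the project's documentation), so check the guides for the supported recipe:

# Rough sketch of the vLLM route; flags, port, and prompt formatting are assumptions.
# 1) Serve the checkpoint with vLLM's OpenAI-compatible server, e.g.:
#        vllm serve zhuohaoyu/RewardAnything-8B-v1 --port 8000
# 2) Query it with any OpenAI-style client; the principle, prompt, and candidate
#    responses must be packed into the model's expected prompt format (not shown here).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy-key")
completion = client.chat.completions.create(
    model="zhuohaoyu/RewardAnything-8B-v1",
    messages=[{"role": "user", "content": "<principle + prompt + responses in the expected format>"}],
    temperature=0.0,
)
print(completion.choices[0].message.content)  # raw judgment text; parsing is up to you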
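Similarly, the Transformers-direct route boils down to loading the checkpoint with Hugging Face transformers and parsing the generated judgment yourself. The sketch below only illustrates the general shape; the chat template, prompt format, generation settings, and output schema are assumptions, so consult the repository's transformers usage example for the exact format:

# Rough sketch of direct transformers usage; prompt format and output parsing are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "zhuohaoyu/RewardAnything-8B-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# The principle, prompt, and candidate responses are packed into a single chat turn.
messages = [{"role": "user", "content": "<principle + prompt + responses in the expected format>"}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(inputs, max_new_tokens=4096, do_sample=False)

# The generated text contains the reasoning plus scores/ranking; parse it manually.
print(tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True))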
🔬 Advanced Usage
Unlock the full potential of RewardAnything by leveraging sophisticated principles and seamlessly integrating it into your RLHF workflows.

🧩 Complex Principles
RewardAnything excels when provided with clear, structured principles, especially for nuanced tasks involving multiple, potentially conflicting objectives. Define criteria, assign weights (e.g., via textual emphasis or explicit percentages), and specify priorities to guide the model's judgment effectively. This allows fine-grained control over the evaluation process; a sketch with explicit percentage weights appears after the snippet below.

# Define a detailed, multi-faceted principle
complex_principle = """
Safety comes first, but also be sure not to encourage
overly sensitive rejections for safe or benignly
borderline queries. Next, equally value warmth,
appropriate humor (to deflect borderline harm),
and genuine helpfulness. Remember, content and tone
are more important than presentation style.
"""

# Assume 'reward_model' is initialized
# prompt = "Your specific prompt here"
# responses = {"res_a": "...", "res_b": "..."}
result = reward_model.judge(
    principle=complex_principle,
    prompt=prompt,
    responses=responses
)
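As noted above, weights can also be spelled out as explicit percentages. The snippet below is a minimal sketch reusing the same judge call; the criteria and numbers are illustrative, not taken from the paper:

# Illustrative only: a principle with explicit percentage weights.
weighted_principle = (
    "Judge responses by factual accuracy (50%), safety (30%), and "
    "clarity of presentation (20%). When criteria conflict, accuracy "
    "outweighs presentation."
)

result = reward_model.judge(
    principle=weighted_principle,
    prompt=prompt,
    responses=responses
)
print(result.ranking)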
🔄 RLHF Integration
Seamlessly integrate RewardAnything into your Reinforcement Learning from Human Feedback (RLHF) pipelines. It can serve as a dynamic, principle-driven reward function and is compatible with popular RL frameworks (e.g., TRL, veRL), allowing you to guide model generation based on explicit criteria rather than static preferences.

Detailed integration examples and best practices can be found in our official repository.

# Example: use RewardAnything as the reward in a PPO-style training loop
# Assume 'reward_model' is initialized
# principle = "Your guiding principle"
# prompt = "The input prompt"

def reward_function(principle, prompt, response_text):
    eval_responses = {"generated": response_text}
    result = reward_model.judge(
        principle=principle,
        prompt=prompt,
        responses=eval_responses
    )
    return result.scores.get("generated", 0.0)

# generated_responses = ["response1", "response2", ...]
rewards = [reward_function(principle, prompt, resp)
           for resp in generated_responses]
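To connect this with a concrete trainer, the sketch below wires the rewards into a TRL-style PPO step. It assumes TRL's classic PPOTrainer interface (one scalar reward tensor per sampled response passed to ppo_trainer.step); newer TRL versions differ, so adapt it to your framework and treat the variable names as placeholders:

# Rough sketch, assuming a classic TRL PPOTrainer set up elsewhere.
import torch

# query_tensors / response_tensors come from your PPO sampling step (not shown);
# decoded_responses are the corresponding decoded strings.
rewards = [
    torch.tensor(reward_function(principle, prompt, text))
    for text in decoded_responses
]

# One scalar reward tensor per sampled response.
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)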
State-of-the-Art Performance
RewardAnything achieves excellent performance on both traditional benchmarks and our new principle-following evaluation. Below are highlights from RM-Bench and our proposed RABench. For full details, additional benchmarks, and ablation studies, please see our paper.

Table 2: Accuracies (%) of reward models on RM-Bench.
Table 3: Performance of reward models on RABench (ours).
Key Innovations
RewardAnything introduces novel techniques for principle-following reward modeling:

• Group Relative Policy Optimization (GRPO): advanced RL training that learns relative preferences within response groups.
• Listwise Evaluation: efficient ranking of multiple responses in a single forward pass.
• Inference-Time Reasoning: an explicit reasoning process for transparent decision making.
• Multi-LLM Consensus: ground truth from 4 state-of-the-art LLMs with algorithmic consensus.
• Human Verification: 89% agreement rate (κ=0.57) for reliable evaluation standards.

RABench: Novel Evaluation Framework
We introduce RABench, a comprehensive benchmark specifically designed to evaluate reward models' ability to follow explicit natural language principles across diverse domains and criteria.

• 1,002 validated rankings across 50 principles
• 5 principle categories: Content, Logic, Style, Tone, Structure
• Multiple domains: Chat, Code, Safety, Math
• Human-verified quality with high inter-annotator agreement
Documentation & Resources
Everything you need to understand and use RewardAnything for your research and applications.
Citation
If you use RewardAnything in your research, please cite our paper:

@article{yu2025rewardanything,
  title={RewardAnything: Generalizable Principle-Following Reward Models},
  author={Yu, Zhuohao and Zeng, Jiali and Gu, Weizheng and Wang, Yidong and
          Wang, Jindong and Meng, Fandong and Zhou, Jie and Zhang, Yue and
          Zhang, Shikun and Ye, Wei},
  journal={arXiv preprint arXiv:2506.03637},
  year={2025}
}

--------------------------------------------------------------------------------