├── Tiny Tool Use ├── src │ ├── data │ │ ├── __init__.py │ │ └── data_generator.py │ ├── models │ │ ├── __init__.py │ │ └── tool_aware_model.py │ ├── tools │ │ ├── __init__.py │ │ └── executor.py │ ├── training │ │ ├── __init__.py │ │ └── trainer.py │ ├── utils │ │ ├── __init__.py │ │ ├── logging_utils.py │ │ ├── config.py │ │ └── evaluation.py │ └── __init__.py ├── requirements.txt ├── configs │ ├── sft_toolbench_config.json │ ├── teacher_mode_config.json │ └── dpo_config.json ├── LICENSE ├── setup.sh ├── examples │ ├── test_cases.json │ └── run_examples.py ├── evaluate.py ├── validate.py ├── tests │ └── test_basic.py ├── save_merge_model.py ├── .gitignore ├── train.py ├── README.md └── demo.py ├── logo.png └── README.md /Tiny Tool Use/src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bageldotcom/bagel-RL/HEAD/logo.png -------------------------------------------------------------------------------- /Tiny Tool Use/src/__init__.py: -------------------------------------------------------------------------------- 1 | """LLM Tool Use Training Playground""" 2 | 3 | __version__ = "0.1.0" 4 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | """Logging utilities.""" 2 | 3 | import logging 4 | from rich.logging import RichHandler 5 | 6 | 7 | def setup_logging(debug: bool = False): 8 | """Setup logging configuration with rich formatting.""" 9 | level = logging.DEBUG if debug else logging.INFO 10 | 11 | logging.basicConfig( 12 | level=level, 13 | format="%(message)s", 14 | datefmt="[%X]", 15 | handlers=[RichHandler(rich_tracebacks=True)] 16 | ) 17 | 18 | # Reduce noise from transformers and other libraries 19 | logging.getLogger("transformers").setLevel(logging.WARNING) 20 | logging.getLogger("datasets").setLevel(logging.WARNING) 21 | logging.getLogger("urllib3").setLevel(logging.WARNING) 22 | -------------------------------------------------------------------------------- /Tiny Tool Use/requirements.txt: -------------------------------------------------------------------------------- 1 | # Core ML/RL dependencies 2 | torch>=2.0.0 3 | transformers>=4.30.0 4 | datasets>=2.12.0 5 | accelerate>=0.20.0 6 | peft>=0.4.0 7 | trl>=0.7.0 8 | 9 | # Tool handling and API 10 | openai>=1.0.0 11 | anthropic>=0.7.0 12 | requests>=2.31.0 13 | jsonschema>=4.17.0 14 | 15 | # Data processing 16 | 
pandas>=2.0.0 17 | numpy>=1.24.0 18 | scikit-learn>=1.3.0 19 | gdown>=5.2.0 20 | 21 | # Evaluation and logging 22 | tensorboard>=2.13.0 23 | rouge-score>=0.1.2 24 | bert-score>=0.3.13 25 | bfcl>=1.0.1 26 | 27 | # Utilities 28 | tqdm>=4.65.0 29 | pyyaml>=6.0 30 | click>=8.1.0 31 | rich>=13.0.0 32 | 33 | # Optional: For paraphrasing models 34 | sentence-transformers>=2.2.0 35 | nltk>=3.8.0 36 | 37 | # Development 38 | pytest>=7.0.0 39 | black>=23.0.0 40 | flake8>=6.0.0 -------------------------------------------------------------------------------- /Tiny Tool Use/configs/sft_toolbench_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "Qwen/Qwen3-0.6B", 4 | "trust_remote_code": true, 5 | "torch_dtype": "float16", 6 | "device_map": "auto" 7 | }, 8 | "training": { 9 | "method": "sft", 10 | "num_epochs": 1, 11 | "learning_rate": 5e-5, 12 | "batch_size": 4, 13 | "gradient_accumulation_steps": 8, 14 | "tokenizer_padding_side": "left", 15 | "warmup_steps": 100, 16 | "max_length": 2048, 17 | "use_lora": true, 18 | "lora_r": 8, 19 | "lora_alpha": 32, 20 | "lora_dropout": 0.05 21 | 22 | }, 23 | "data": { 24 | "strategy": "toolbench", 25 | "generation_type": "real", 26 | "max_samples": 700, 27 | "train_split": 0.99 28 | }, 29 | 30 | "tokenizer":{ 31 | "name":"Qwen/Qwen3-0.6B", 32 | "trust_remote_code": true 33 | }, 34 | 35 | "tools": 36 | { 37 | 38 | } 39 | 40 | 41 | } 42 | -------------------------------------------------------------------------------- /Tiny Tool Use/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Tiny Tool Use 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /Tiny Tool Use/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # LLM Tool Use Training Playground Setup Script 4 | 5 | set -e 6 | 7 | echo "🚀 Setting up LLM Tool Use Training Playground..." 8 | 9 | # Check if Python is available 10 | if ! command -v python3 &> /dev/null; then 11 | echo "❌ Python 3 is required but not installed." 12 | exit 1 13 | fi 14 | 15 | # Create virtual environment 16 | echo "📦 Creating virtual environment..." 17 | python3 -m venv venv 18 | 19 | # Activate virtual environment 20 | echo "🔄 Activating virtual environment..." 
21 | source venv/bin/activate
22 | 
23 | # Upgrade pip
24 | echo "⬆️ Upgrading pip..."
25 | pip install --upgrade pip
26 | 
27 | # Install requirements
28 | echo "📚 Installing requirements..."
29 | pip install -r requirements.txt
30 | 
31 | # Create output directories
32 | echo "📁 Creating output directories..."
33 | mkdir -p outputs logs
34 | 
35 | echo "✅ Setup complete!"
36 | echo ""
37 | echo "To get started:"
38 | echo "1. Activate the virtual environment: source venv/bin/activate"
39 | echo "2. Run a training example: python train.py --config configs/sft_toolbench_config.json"
40 | echo "3. Or run the interactive examples: python examples/run_examples.py"
41 | echo ""
42 | echo "For more information, see README.md"
43 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
Open-source RL for distributed learning systems by Bagel Labs
21 | 22 | --- 23 | 24 | ## Overview 25 | 26 | This repository is a collection of open-source distributed and non-distributed reinforcement learning projects by Bagel Labs. 27 | 28 | --- 29 | 30 | ## Projects 31 | 32 | * [**Tiny Tool Use**](Tiny%20Tool%20Use/) – an intentionally-tiny yet production-ready library for fine-tuning LLMs to make robust, auditable tool calls. 33 | * More projects coming soon — give the repo a star to receive updates. 34 | 35 | --- 36 | 37 | ## License 38 | 39 | See individual sub-project directories for license details. 40 | -------------------------------------------------------------------------------- /Tiny Tool Use/examples/test_cases.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "prompt": "Human: What is 25 multiplied by 4?", 4 | "expected_tool": "calculator", 5 | "expected_params": {"expression": "25 * 4"} 6 | }, 7 | { 8 | "prompt": "Human: I need to compute 15 plus 30 minus 8", 9 | "expected_tool": "calculator", 10 | "expected_params": {"expression": "15 + 30 - 8"} 11 | }, 12 | { 13 | "prompt": "Human: What's the weather forecast for New York City?", 14 | "expected_tool": "weather", 15 | "expected_params": {"location": "New York City"} 16 | }, 17 | { 18 | "prompt": "Human: Can you check the weather in London?", 19 | "expected_tool": "weather", 20 | "expected_params": {"location": "London"} 21 | }, 22 | { 23 | "prompt": "Human: Search for Python machine learning tutorials", 24 | "expected_tool": "search", 25 | "expected_params": {"query": "Python machine learning tutorials"} 26 | }, 27 | { 28 | "prompt": "Human: Find information about reinforcement learning", 29 | "expected_tool": "search", 30 | "expected_params": {"query": "reinforcement learning"} 31 | }, 32 | { 33 | "prompt": "Human: Calculate the square root of 144", 34 | "expected_tool": "calculator", 35 | "expected_params": {"expression": "144 ** 0.5"} 36 | }, 37 | { 38 | "prompt": "Human: What's the temperature in Tokyo right now?", 39 | "expected_tool": "weather", 40 | "expected_params": {"location": "Tokyo"} 41 | }, 42 | { 43 | "prompt": "Human: Look up information about neural networks", 44 | "expected_tool": "search", 45 | "expected_params": {"query": "neural networks"} 46 | }, 47 | { 48 | "prompt": "Human: Compute 2 to the power of 8", 49 | "expected_tool": "calculator", 50 | "expected_params": {"expression": "2 ** 8"} 51 | } 52 | ] 53 | -------------------------------------------------------------------------------- /Tiny Tool Use/examples/run_examples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Example usage script demonstrating different training modes. 
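Prompts before launching each bundled config (DPO, SFT on ToolBench) through train.py, and prints the matching evaluate.py command when a run succeeds.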
4 | """ 5 | 6 | import subprocess 7 | import sys 8 | from pathlib import Path 9 | 10 | 11 | def run_training_example(config_name: str): 12 | """Run a training example.""" 13 | config_path = f"configs/{config_name}" 14 | 15 | if not Path(config_path).exists(): 16 | print(f"Configuration file {config_path} not found!") 17 | return False 18 | 19 | print(f"\n🚀 Running training with {config_name}...") 20 | 21 | try: 22 | result = subprocess.run([ 23 | sys.executable, "train.py", 24 | "--config", config_path, 25 | "--output-dir", f"outputs/{config_name.replace('.json', '')}" 26 | ], check=True, capture_output=True, text=True) 27 | 28 | print("✅ Training completed successfully!") 29 | print(result.stdout) 30 | return True 31 | 32 | except subprocess.CalledProcessError as e: 33 | print(f"❌ Training failed: {e}") 34 | print(f"Error output: {e.stderr}") 35 | return False 36 | 37 | 38 | def main(): 39 | print("🎯 LLM Tool Use Training Playground - Example Usage") 40 | print("=" * 50) 41 | 42 | examples = [ 43 | ("dpo_config.json", "DPO Training with manual templates"), 44 | ("sft_toolbech_config.json", "Supervised Fine-tuning with Toolbench Data") 45 | 46 | ] 47 | 48 | for config_file, description in examples: 49 | print(f"\n📋 Example: {description}") 50 | print(f"Config: {config_file}") 51 | 52 | response = input("Run this example? (y/n): ").lower().strip() 53 | if response == 'y': 54 | success = run_training_example(config_file) 55 | if success: 56 | print("You can now evaluate the model with:") 57 | print(f"python evaluate.py --model-path outputs/{config_file.replace('.json', '')} --config configs/{config_file}") 58 | else: 59 | print("Skipping...") 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /Tiny Tool Use/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Evaluation script for trained tool use models. 
4 | """ 5 | 6 | import argparse 7 | import json 8 | import logging 9 | 10 | from src.utils.evaluation import ToolUseEvaluator, create_test_cases 11 | from src.utils.config import ConfigManager 12 | from src.utils.logging_utils import setup_logging 13 | 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser(description="Evaluate trained tool use model") 17 | parser.add_argument( 18 | "--model-path", 19 | type=str, 20 | required=True, 21 | help="Path to trained model" 22 | ) 23 | parser.add_argument( 24 | "--config", 25 | type=str, 26 | required=True, 27 | help="Path to configuration file" 28 | ) 29 | parser.add_argument( 30 | "--output", 31 | type=str, 32 | default="evaluation_results.json", 33 | help="Output file for results" 34 | ) 35 | parser.add_argument( 36 | "--test-cases", 37 | type=str, 38 | default=None, 39 | help="Custom test cases JSON file" 40 | ) 41 | 42 | args = parser.parse_args() 43 | 44 | setup_logging() 45 | logger = logging.getLogger(__name__) 46 | 47 | # Load configuration 48 | config_manager = ConfigManager(args.config) 49 | config = config_manager.get_config() 50 | 51 | # Initialize evaluator 52 | evaluator = ToolUseEvaluator(args.model_path, config["tools"]) 53 | 54 | # Load test cases 55 | if args.test_cases: 56 | with open(args.test_cases, 'r') as f: 57 | test_cases = json.load(f) 58 | else: 59 | test_cases = create_test_cases() 60 | 61 | logger.info(f"Evaluating model on {len(test_cases)} test cases...") 62 | 63 | # Run evaluation 64 | results = evaluator.evaluate_specific_cases(test_cases) 65 | 66 | # Save results 67 | with open(args.output, 'w') as f: 68 | json.dump(results, f, indent=2) 69 | 70 | # Print summary 71 | print("\nEvaluation Results:") 72 | print(f"Tool Accuracy: {results['tool_accuracy']:.3f}") 73 | print(f"Format Accuracy: {results['format_accuracy']:.3f}") 74 | print(f"Execution Accuracy: {results['execution_accuracy']:.3f}") 75 | print(f"\nDetailed results saved to: {args.output}") 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /Tiny Tool Use/validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Quick validation script to check if the playground is set up correctly. 
4 | """ 5 | 6 | import json 7 | import sys 8 | from pathlib import Path 9 | 10 | def check_project_structure(): 11 | """Check if all required files are present.""" 12 | required_files = [ 13 | "requirements.txt", 14 | "train.py", 15 | "evaluate.py", 16 | "configs/calculator_config.json", 17 | "configs/ppo_config.json", 18 | "src/tools/executor.py", 19 | "src/data/data_generator.py", 20 | "src/training/trainer.py", 21 | "examples/test_cases.json" 22 | ] 23 | 24 | missing_files = [] 25 | for file_path in required_files: 26 | if not Path(file_path).exists(): 27 | missing_files.append(file_path) 28 | 29 | if missing_files: 30 | print(f"❌ Missing files: {missing_files}") 31 | return False 32 | else: 33 | print("✅ All required files present") 34 | return True 35 | 36 | 37 | def check_configurations(): 38 | """Check if configuration files are valid JSON.""" 39 | config_files = [ 40 | "configs/calculator_config.json", 41 | "configs/ppo_config.json" 42 | ] 43 | 44 | for config_file in config_files: 45 | try: 46 | with open(config_file, 'r') as f: 47 | json.load(f) 48 | print(f"✅ {config_file} is valid JSON") 49 | except Exception as e: 50 | print(f"❌ {config_file} has invalid JSON: {e}") 51 | return False 52 | 53 | return True 54 | 55 | 56 | def main(): 57 | print("🔍 LLM Tool Use Training Playground - Validation Check") 58 | print("=" * 55) 59 | 60 | checks = [ 61 | ("Project Structure", check_project_structure), 62 | ("Configuration Files", check_configurations) 63 | ] 64 | 65 | all_passed = True 66 | for check_name, check_func in checks: 67 | print(f"\n📋 Checking {check_name}...") 68 | if not check_func(): 69 | all_passed = False 70 | 71 | print("\n" + "=" * 55) 72 | if all_passed: 73 | print("🎉 All checks passed! The playground is ready to use.") 74 | print("\nNext steps:") 75 | print("1. Run: ./setup.sh (to install dependencies)") 76 | print("2. Run: python train.py --config configs/calculator_config.json") 77 | print("3. Or try: python examples/run_examples.py") 78 | else: 79 | print("❌ Some checks failed. 
Please review the errors above.") 80 | sys.exit(1) 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /Tiny Tool Use/configs/teacher_mode_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "Qwen/Qwen3-0.6B", 4 | "trust_remote_code": false, 5 | "torch_dtype": "float16", 6 | "device_map": "auto" 7 | }, 8 | 9 | "tokenizer":{ 10 | "name":"Qwen/Qwen3-0.6B", 11 | "trust_remote_code": true 12 | }, 13 | 14 | "training": { 15 | "method": "teacher_mode", 16 | "num_epochs": 2, 17 | "learning_rate": 3e-5, 18 | "batch_size": 2, 19 | "gradient_accumulation_steps": 2, 20 | "warmup_steps": 50, 21 | "max_length": 384, 22 | "use_lora": true, 23 | "lora_r": 8, 24 | "lora_alpha": 16, 25 | "lora_dropout": 0.1 26 | }, 27 | "data": { 28 | "strategy": "teacher_mode", 29 | "generation_type":"synthetic", 30 | "max_samples": 150, 31 | "train_split": 0.85 32 | }, 33 | "tools": [ 34 | { 35 | "name": "calculator", 36 | "description": "Perform mathematical calculations including basic arithmetic and simple expressions", 37 | "type": "function", 38 | "function": "calculator", 39 | "parameters": { 40 | "type": "object", 41 | "properties": { 42 | "expression": { 43 | "type": "string", 44 | "description": "Mathematical expression to evaluate (supports +, -, *, /, **, parentheses)" 45 | } 46 | }, 47 | "required": ["expression"] 48 | } 49 | }, 50 | { 51 | "name": "weather", 52 | "description": "Get current weather information for any location worldwide", 53 | "type": "function", 54 | "function": "weather", 55 | "parameters": { 56 | "type": "object", 57 | "properties": { 58 | "location": { 59 | "type": "string", 60 | "description": "City name, state/country (e.g., 'Paris', 'New York, NY', 'Tokyo, Japan')" 61 | } 62 | }, 63 | "required": ["location"] 64 | } 65 | }, 66 | { 67 | "name": "search", 68 | "description": "Search for information on the internet about any topic", 69 | "type": "function", 70 | "function": "search", 71 | "parameters": { 72 | "type": "object", 73 | "properties": { 74 | "query": { 75 | "type": "string", 76 | "description": "Search query or topic to find information about" 77 | } 78 | }, 79 | "required": ["query"] 80 | } 81 | } 82 | ], 83 | "evaluation": { 84 | "metrics": ["tool_accuracy", "format_correctness", "execution_success", "response_quality"], 85 | "eval_steps": 25 86 | }, 87 | "tensorboard": { 88 | "enabled": true, 89 | "log_dir": "outputs/runs" 90 | }, 91 | "seed": 42 92 | } 93 | -------------------------------------------------------------------------------- /Tiny Tool Use/configs/dpo_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": { 3 | "name": "Qwen/Qwen3-0.6B", 4 | "trust_remote_code": true, 5 | "torch_dtype": "float16", 6 | "device_map": "auto" 7 | }, 8 | 9 | "tokenizer":{ 10 | "name":"Qwen/Qwen3-0.6B", 11 | "trust_remote_code": true, 12 | "padding_side": "left" 13 | }, 14 | "training": { 15 | "method": "dpo", 16 | "num_epochs": 100, 17 | "learning_rate": 1e-6, 18 | "batch_size": 2, 19 | "gradient_accumulation_steps": 16, 20 | "warmup_steps": 100, 21 | "max_length": 512, 22 | "use_lora": true, 23 | "lora_r": 16, 24 | "lora_alpha": 32, 25 | "lora_dropout": 0.1, 26 | "max_grad_norm": 0.3, 27 | "bf16": false, 28 | "fp16": true, 29 | "adam_beta1": 0.9, 30 | "adam_beta2": 0.999, 31 | "adam_epsilon": 1e-8, 32 | "weight_decay": 0.01, 33 | "optim": "paged_adamw_8bit" 34 | }, 
35 | "data": { 36 | "strategy": "manual_templates", 37 | "generation_type": "synthetic", 38 | "max_samples": 1000, 39 | "train_split": 0.8 40 | }, 41 | "tools": [ 42 | { 43 | "name": "calculator", 44 | "description": "Perform mathematical calculations", 45 | "type": "function", 46 | "function": "calculator", 47 | "parameters": { 48 | "type": "object", 49 | "properties": { 50 | "expression": { 51 | "type": "string", 52 | "description": "Mathematical expression to evaluate" 53 | } 54 | }, 55 | "required": ["expression"] 56 | } 57 | }, 58 | { 59 | "name": "weather", 60 | "description": "Get weather information for a location", 61 | "type": "function", 62 | "function": "weather", 63 | "parameters": { 64 | "type": "object", 65 | "properties": { 66 | "location": { 67 | "type": "string", 68 | "description": "Location to get weather for" 69 | } 70 | }, 71 | "required": ["location"] 72 | } 73 | }, 74 | { 75 | "name": "search", 76 | "description": "Search for information on the internet", 77 | "type": "function", 78 | "function": "search", 79 | "parameters": { 80 | "type": "object", 81 | "properties": { 82 | "query": { 83 | "type": "string", 84 | "description": "Search query" 85 | } 86 | }, 87 | "required": ["query"] 88 | } 89 | } 90 | ], 91 | "evaluation": { 92 | "metrics": ["tool_accuracy", "response_quality", "tool_format_correctness"], 93 | "eval_steps": 100 94 | }, 95 | "tensorboard": { 96 | "enabled": true, 97 | "log_dir": "outputs/runs" 98 | }, 99 | "seed": 42 100 | } 101 | -------------------------------------------------------------------------------- /Tiny Tool Use/tests/test_basic.py: -------------------------------------------------------------------------------- 1 | """Tests for the tool training playground.""" 2 | 3 | import json 4 | import sys 5 | from pathlib import Path 6 | 7 | # Add src to path 8 | sys.path.insert(0, str(Path(__file__).parent.parent / "src")) 9 | 10 | from tools.executor import ToolExecutor 11 | from data.data_generator import DataGenerator 12 | from utils.config import ConfigManager 13 | 14 | 15 | def test_tool_executor(): 16 | """Test basic tool execution.""" 17 | tools_config = [ 18 | { 19 | "name": "calculator", 20 | "description": "Perform calculations", 21 | "type": "function", 22 | "function": "calculator", 23 | "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}} 24 | } 25 | ] 26 | 27 | executor = ToolExecutor(tools_config) 28 | 29 | # Test calculator 30 | result = executor.execute_tool("calculator", {"expression": "2 + 3"}) 31 | assert "result" in result 32 | assert result["result"] == 5 33 | 34 | # Test invalid tool 35 | result = executor.execute_tool("nonexistent", {}) 36 | assert "error" in result 37 | 38 | 39 | def test_data_generator(): 40 | """Test data generation.""" 41 | data_config = { 42 | "strategy": "manual_templates", 43 | "max_samples": 10, 44 | "train_split": 0.8 45 | } 46 | 47 | tools_config = [ 48 | { 49 | "name": "calculator", 50 | "description": "Perform calculations", 51 | "type": "function", 52 | "function": "calculator", 53 | "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}} 54 | } 55 | ] 56 | 57 | generator = DataGenerator(data_config, tools_config) 58 | train_dataset, eval_dataset = generator.prepare_datasets() 59 | 60 | assert len(train_dataset) > 0 61 | assert len(eval_dataset) > 0 62 | assert "text" in train_dataset[0] 63 | 64 | 65 | def test_config_validation(): 66 | """Test configuration validation.""" 67 | # Create a temporary config file 68 | config_data = { 69 | 
"model": {"name": "test-model"}, 70 | "training": {"method": "sft"}, 71 | "data": {"strategy": "manual_templates"}, 72 | "tools": [{"name": "test", "description": "test", "parameters": {}}] 73 | } 74 | 75 | config_file = Path("/tmp/test_config.json") 76 | with open(config_file, 'w') as f: 77 | json.dump(config_data, f) 78 | 79 | try: 80 | config_manager = ConfigManager(str(config_file)) 81 | config = config_manager.get_config() 82 | assert config["model"]["name"] == "test-model" 83 | finally: 84 | config_file.unlink(missing_ok=True) 85 | 86 | 87 | if __name__ == "__main__": 88 | # Run basic tests 89 | test_tool_executor() 90 | test_data_generator() 91 | test_config_validation() 92 | print("✅ All tests passed!") 93 | -------------------------------------------------------------------------------- /Tiny Tool Use/save_merge_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from pathlib import Path 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | from peft import PeftModel 7 | 8 | def parse_arguments(): 9 | """Parse command line arguments.""" 10 | parser = argparse.ArgumentParser( 11 | description="Merge LoRA adapters with a base model and save the result." 12 | ) 13 | 14 | parser.add_argument( 15 | "--base_model", 16 | type=str, 17 | required=True, 18 | help="Path or name of the base model to merge adapters with" 19 | ) 20 | 21 | parser.add_argument( 22 | "--adapter_path", 23 | type=str, 24 | required=True, 25 | help="Directory containing the adapter weights" 26 | ) 27 | 28 | parser.add_argument( 29 | "--output_dir", 30 | type=str, 31 | required=True, 32 | help="Directory where to save the merged model" 33 | ) 34 | 35 | parser.add_argument( 36 | "--half_precision", 37 | action="store_true", 38 | help="Save the merged model in half precision (float16)" 39 | ) 40 | 41 | parser.add_argument( 42 | "--trust_remote_code", 43 | action="store_true", 44 | help="Allow loading remote code for tokenizer and model" 45 | ) 46 | 47 | parser.add_argument( 48 | "--no_safetensors", 49 | action="store_true", 50 | help="Don't use safetensors format for saving" 51 | ) 52 | 53 | return parser.parse_args() 54 | 55 | def main(): 56 | # Parse command line arguments 57 | args = parse_arguments() 58 | 59 | print(f"Loading base model: {args.base_model}") 60 | base_model = AutoModelForCausalLM.from_pretrained( 61 | args.base_model, 62 | torch_dtype=torch.float16 if args.half_precision else torch.float32, 63 | low_cpu_mem_usage=True, 64 | trust_remote_code=args.trust_remote_code, 65 | ) 66 | 67 | print(f"Loading adapter from: {args.adapter_path}") 68 | model = PeftModel.from_pretrained(base_model, args.adapter_path) 69 | 70 | # Merge adapter weights with base model 71 | print("Merging adapter with base model...") 72 | model = model.merge_and_unload() 73 | 74 | # Ensure output directory exists 75 | output_dir = Path(args.output_dir) 76 | output_dir.mkdir(parents=True, exist_ok=True) 77 | 78 | # Save the merged model 79 | print(f"Saving merged model to {args.output_dir}") 80 | model.save_pretrained( 81 | args.output_dir, 82 | safe_serialization=not args.no_safetensors 83 | ) 84 | 85 | # Save tokenizer 86 | print("Saving tokenizer...") 87 | tokenizer = AutoTokenizer.from_pretrained( 88 | args.base_model, 89 | trust_remote_code=args.trust_remote_code 90 | ) 91 | tokenizer.save_pretrained(args.output_dir) 92 | 93 | print(f"✓ Model successfully merged and saved to {args.output_dir}") 94 | 95 | if __name__ == 
"__main__": 96 | main() 97 | 98 | 99 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/models/tool_aware_model.py: -------------------------------------------------------------------------------- 1 | """Model utilities and custom architectures.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import AutoModel, AutoConfig 6 | from typing import Dict, Any 7 | 8 | 9 | class ToolAwareModel(nn.Module): 10 | """A model wrapper that's aware of tool calling patterns.""" 11 | 12 | def __init__(self, base_model, config: Dict[str, Any]): 13 | super().__init__() 14 | self.base_model = base_model 15 | self.config = config 16 | 17 | # Tool detection head 18 | self.tool_detector = nn.Linear( 19 | base_model.config.hidden_size, 20 | len(config.get("tools", [])) 21 | ) 22 | 23 | # Tool confidence head 24 | self.confidence_head = nn.Linear( 25 | base_model.config.hidden_size, 26 | 1 27 | ) 28 | 29 | def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): 30 | """Forward pass with tool awareness.""" 31 | outputs = self.base_model( 32 | input_ids=input_ids, 33 | attention_mask=attention_mask, 34 | labels=labels, 35 | **kwargs 36 | ) 37 | 38 | # Get last hidden state for tool detection 39 | last_hidden_state = outputs.hidden_states[-1] if hasattr(outputs, 'hidden_states') else None 40 | 41 | if last_hidden_state is not None: 42 | # Tool detection logits 43 | tool_logits = self.tool_detector(last_hidden_state[:, -1, :]) 44 | 45 | # Tool confidence 46 | confidence = torch.sigmoid(self.confidence_head(last_hidden_state[:, -1, :])) 47 | 48 | # Add to outputs 49 | outputs.tool_logits = tool_logits 50 | outputs.tool_confidence = confidence 51 | 52 | return outputs 53 | 54 | 55 | class RewardModel(nn.Module): 56 | """Reward model for tool use success.""" 57 | 58 | def __init__(self, base_model_name: str, num_tools: int): 59 | super().__init__() 60 | 61 | # Load base model 62 | config = AutoConfig.from_pretrained(base_model_name) 63 | self.base_model = AutoModel.from_pretrained(base_model_name, config=config) 64 | 65 | # Reward head 66 | self.reward_head = nn.Sequential( 67 | nn.Linear(config.hidden_size, config.hidden_size // 2), 68 | nn.ReLU(), 69 | nn.Dropout(0.1), 70 | nn.Linear(config.hidden_size // 2, 1) 71 | ) 72 | 73 | # Tool-specific reward heads 74 | self.tool_reward_heads = nn.ModuleList([ 75 | nn.Linear(config.hidden_size, 1) for _ in range(num_tools) 76 | ]) 77 | 78 | def forward(self, input_ids, attention_mask=None, tool_id=None): 79 | """Forward pass for reward calculation.""" 80 | outputs = self.base_model( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask 83 | ) 84 | 85 | # Get pooled representation 86 | pooled_output = outputs.last_hidden_state.mean(dim=1) 87 | 88 | # General reward 89 | reward = self.reward_head(pooled_output) 90 | 91 | # Tool-specific reward if tool_id provided 92 | tool_reward = None 93 | if tool_id is not None and 0 <= tool_id < len(self.tool_reward_heads): 94 | tool_reward = self.tool_reward_heads[tool_id](pooled_output) 95 | 96 | return { 97 | "reward": reward, 98 | "tool_reward": tool_reward, 99 | "pooled_output": pooled_output 100 | } 101 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/utils/config.py: -------------------------------------------------------------------------------- 1 | """Configuration management utilities.""" 2 | 3 | import json 4 | import jsonschema 5 | from pathlib import Path 6 | from typing import 
Dict, Any 7 | 8 | 9 | class ConfigManager: 10 | """Manages configuration loading and validation.""" 11 | 12 | def __init__(self, config_path: str): 13 | self.config_path = Path(config_path) 14 | self.config = self._load_config() 15 | # do not run the configuration validation for toolbench 16 | if not self.config["data"]["strategy"].lower()=="toolbench": 17 | self._validate_config() 18 | 19 | def _load_config(self) -> Dict[str, Any]: 20 | """Load configuration from JSON file.""" 21 | with open(self.config_path, 'r') as f: 22 | return json.load(f) 23 | 24 | def _validate_config(self): 25 | """Validate configuration against schema.""" 26 | schema = { 27 | "type": "object", 28 | "required": ["model", "training", "data", "tools"], 29 | "properties": { 30 | "model": { 31 | "type": "object", 32 | "required": ["name"], 33 | "properties": { 34 | "name": {"type": "string"}, 35 | "trust_remote_code": {"type": "boolean"}, 36 | "torch_dtype": {"type": "string"}, 37 | "device_map": {"type": "string"} 38 | } 39 | }, 40 | "training": { 41 | "type": "object", 42 | "required": ["method"], 43 | "properties": { 44 | "method": {"enum": ["sft", "ppo", "dpo", "teacher_mode"]}, 45 | "num_epochs": {"type": "integer", "minimum": 1}, 46 | "learning_rate": {"type": "number", "minimum": 0}, 47 | "batch_size": {"type": "integer", "minimum": 1}, 48 | "gradient_accumulation_steps": {"type": "integer", "minimum": 1}, 49 | "warmup_steps": {"type": "integer", "minimum": 0}, 50 | "max_length": {"type": "integer", "minimum": 1} 51 | } 52 | }, 53 | "data": { 54 | "type": "object", 55 | "required": ["strategy"], 56 | "properties": { 57 | "strategy": {"enum": ["toolbench", "teacher_mode", "manual_templates"]}, 58 | "max_samples": {"type": "integer", "minimum": 1}, 59 | "train_split": {"type": "number", "minimum": 0, "maximum": 1} 60 | } 61 | }, 62 | "tools": { 63 | "type": "array", 64 | "minItems": 1, 65 | "items": { 66 | "type": "object", 67 | "required": ["name", "description", "parameters"], 68 | "properties": { 69 | "name": {"type": "string"}, 70 | "description": {"type": "string"}, 71 | "parameters": {"type": "object"} 72 | } 73 | } 74 | } 75 | } 76 | } 77 | 78 | try: 79 | jsonschema.validate(self.config, schema) 80 | except jsonschema.ValidationError as e: 81 | raise ValueError(f"Invalid configuration: {e.message}") 82 | 83 | def get_config(self) -> Dict[str, Any]: 84 | """Return the loaded configuration.""" 85 | return self.config 86 | 87 | def get_section(self, section: str) -> Dict[str, Any]: 88 | """Get a specific configuration section.""" 89 | return self.config.get(section, {}) 90 | -------------------------------------------------------------------------------- /Tiny Tool Use/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be added to the global gitignore or merged into this project gitignore. 
For a PyCharm 158 | # project, this template is based on a Python project template available from: 159 | # https://github.com/github/gitignore/blob/main/Python.gitignore 160 | .idea/ 161 | 162 | # TensorBoard logs 163 | runs/ 164 | logs/ 165 | 166 | # Model outputs and training artifacts 167 | outputs/checkpoints/ 168 | outputs/models/ 169 | outputs/logs/ -------------------------------------------------------------------------------- /Tiny Tool Use/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | LLM Tool Use Training Playground 4 | Main training script supporting multiple training recipes and data sources. 5 | """ 6 | 7 | import argparse 8 | import json 9 | import logging 10 | from pathlib import Path 11 | 12 | import torch 13 | from rich.console import Console 14 | 15 | from src.data.data_generator import DataGenerator 16 | from src.training.trainer import ToolTrainer 17 | from src.utils.config import ConfigManager 18 | from src.utils.logging_utils import setup_logging 19 | 20 | 21 | console = Console() 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser(description="Train LLM for tool use") 26 | parser.add_argument( 27 | "--config", 28 | type=str, 29 | required=True, 30 | help="Path to configuration file" 31 | ) 32 | parser.add_argument( 33 | "--output-dir", 34 | type=str, 35 | default="./outputs", 36 | help="Output directory for models and logs" 37 | ) 38 | parser.add_argument( 39 | "--resume", 40 | type=str, 41 | default=None, 42 | help="Resume training from checkpoint" 43 | ) 44 | parser.add_argument( 45 | "--debug", 46 | action="store_true", 47 | help="Enable debug logging" 48 | ) 49 | 50 | args = parser.parse_args() 51 | 52 | # Setup logging 53 | setup_logging(debug=args.debug) 54 | logger = logging.getLogger(__name__) 55 | 56 | # Load configuration 57 | config_manager = ConfigManager(args.config) 58 | config = config_manager.get_config() 59 | 60 | logger.info(f"Loaded configuration from {args.config}") 61 | logger.info(f"Training method: {config['training']['method']}") 62 | logger.info(f"Data strategy: {config['data']['strategy']}") 63 | 64 | # TensorBoard logging is handled by the trainer itself 65 | # No additional initialization needed here 66 | 67 | # Set random seeds for reproducibility 68 | torch.manual_seed(config.get("seed", 42)) 69 | if torch.cuda.is_available(): 70 | torch.cuda.manual_seed_all(config.get("seed", 42)) 71 | 72 | # Create output directory 73 | output_dir = Path(args.output_dir) 74 | output_dir.mkdir(parents=True, exist_ok=True) 75 | 76 | # Save config to output directory 77 | config_path = output_dir / "config.json" 78 | with open(config_path, "w") as f: 79 | json.dump(config, f, indent=2) 80 | 81 | try: 82 | # Initialize data generator 83 | console.print("🔄 [bold blue]Initializing data generator...[/bold blue]") 84 | data_generator = DataGenerator(config["data"], config["tools"], config["tokenizer"]) 85 | 86 | # Generate or load training data 87 | console.print("📊 [bold blue]Preparing training data...[/bold blue]") 88 | train_dataset, eval_dataset = data_generator.prepare_datasets() 89 | 90 | logger.info(f"Training samples: {len(train_dataset)}") 91 | logger.info(f"Evaluation samples: {len(eval_dataset)}") 92 | 93 | # Initialize trainer 94 | console.print("🚀 [bold blue]Initializing trainer...[/bold blue]") 95 | trainer = ToolTrainer( 96 | config=config, 97 | train_dataset=train_dataset, 98 | eval_dataset=eval_dataset, 99 | output_dir=output_dir 100 | ) 101 | 102 
| # Start training 103 | console.print("🎯 [bold green]Starting training...[/bold green]") 104 | trainer.train(resume_from_checkpoint=args.resume) 105 | 106 | console.print("✅ [bold green]Training completed successfully![/bold green]") 107 | 108 | # Clean up resources 109 | trainer.cleanup() 110 | 111 | except Exception as e: 112 | logger.error(f"Training failed: {str(e)}") 113 | console.print(f"❌ [bold red]Training failed: {str(e)}[/bold red]") 114 | # Clean up even on failure 115 | if 'trainer' in locals(): 116 | trainer.cleanup() 117 | raise 118 | 119 | # No cleanup needed for TensorBoard 120 | # The trainer handles the SummaryWriter lifecycle 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/tools/executor.py: -------------------------------------------------------------------------------- 1 | """Tool management and execution utilities.""" 2 | 3 | import logging 4 | import subprocess 5 | from typing import Dict, Any, List, Optional 6 | import requests 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class ToolExecutor: 13 | """Handles execution of external tools.""" 14 | 15 | def __init__(self, tools_config: List[Dict[str, Any]]): 16 | self.tools = {tool["name"]: tool for tool in tools_config} 17 | logger.info(f"Initialized {len(self.tools)} tools: {list(self.tools.keys())}") 18 | 19 | def execute_tool(self, tool_name: str, parameters: Dict[str, Any]) -> Dict[str, Any]: 20 | """Execute a tool with given parameters.""" 21 | if tool_name not in self.tools: 22 | return {"error": f"Tool '{tool_name}' not found"} 23 | 24 | tool = self.tools[tool_name] 25 | 26 | try: 27 | # Handle different tool types 28 | if tool.get("type") == "api": 29 | return self._execute_api_tool(tool, parameters) 30 | elif tool.get("type") == "function": 31 | return self._execute_function_tool(tool, parameters) 32 | elif tool.get("type") == "command": 33 | return self._execute_command_tool(tool, parameters) 34 | else: 35 | # Default to function execution 36 | return self._execute_function_tool(tool, parameters) 37 | 38 | except Exception as e: 39 | logger.error(f"Error executing tool {tool_name}: {str(e)}") 40 | return {"error": str(e)} 41 | 42 | def _execute_api_tool(self, tool: Dict[str, Any], parameters: Dict[str, Any]) -> Dict[str, Any]: 43 | """Execute an API-based tool.""" 44 | url = tool["url"] 45 | method = tool.get("method", "POST") 46 | headers = tool.get("headers", {}) 47 | 48 | if method.upper() == "POST": 49 | response = requests.post(url, json=parameters, headers=headers) 50 | elif method.upper() == "GET": 51 | response = requests.get(url, params=parameters, headers=headers) 52 | else: 53 | return {"error": f"Unsupported HTTP method: {method}"} 54 | 55 | response.raise_for_status() 56 | return response.json() 57 | 58 | def _execute_function_tool(self, tool: Dict[str, Any], parameters: Dict[str, Any]) -> Dict[str, Any]: 59 | """Execute a built-in function tool.""" 60 | function_name = tool.get("function", tool["name"]) 61 | 62 | # Built-in calculator example 63 | if function_name == "calculator": 64 | return self._calculator(parameters) 65 | elif function_name == "weather": 66 | return self._weather_mock(parameters) 67 | elif function_name == "search": 68 | return self._search_mock(parameters) 69 | else: 70 | return {"error": f"Function '{function_name}' not implemented"} 71 | 72 | def _execute_command_tool(self, tool: Dict[str, Any], parameters: Dict[str, Any]) -> Dict[str, 
Any]: 73 | """Execute a command-line tool.""" 74 | command = tool["command"].format(**parameters) 75 | 76 | try: 77 | result = subprocess.run( 78 | command, 79 | shell=True, 80 | capture_output=True, 81 | text=True, 82 | timeout=30 83 | ) 84 | 85 | return { 86 | "stdout": result.stdout, 87 | "stderr": result.stderr, 88 | "returncode": result.returncode 89 | } 90 | except subprocess.TimeoutExpired: 91 | return {"error": "Command timed out"} 92 | 93 | def _calculator(self, parameters: Dict[str, Any]) -> Dict[str, Any]: 94 | """Simple calculator implementation.""" 95 | expression = parameters.get("expression", "") 96 | 97 | try: 98 | # Safe evaluation of mathematical expressions 99 | allowed_chars = set("0123456789+-*/.() ") 100 | if not all(c in allowed_chars for c in expression): 101 | return {"error": "Invalid characters in expression"} 102 | 103 | result = eval(expression) 104 | return {"result": result} 105 | except Exception as e: 106 | return {"error": f"Calculation error: {str(e)}"} 107 | 108 | def _weather_mock(self, parameters: Dict[str, Any]) -> Dict[str, Any]: 109 | """Mock weather API.""" 110 | location = parameters.get("location", "Unknown") 111 | return { 112 | "location": location, 113 | "temperature": "22°C", 114 | "condition": "Sunny", 115 | "humidity": "65%" 116 | } 117 | 118 | def _search_mock(self, parameters: Dict[str, Any]) -> Dict[str, Any]: 119 | """Mock search API.""" 120 | query = parameters.get("query", "") 121 | return { 122 | "query": query, 123 | "results": [ 124 | {"title": f"Result 1 for {query}", "url": "https://example.com/1"}, 125 | {"title": f"Result 2 for {query}", "url": "https://example.com/2"} 126 | ] 127 | } 128 | 129 | def get_tool_schema(self, tool_name: str) -> Optional[Dict[str, Any]]: 130 | """Get the schema for a specific tool.""" 131 | if tool_name in self.tools: 132 | return self.tools[tool_name] 133 | return None 134 | 135 | def list_tools(self) -> List[str]: 136 | """List all available tools.""" 137 | return list(self.tools.keys()) 138 | -------------------------------------------------------------------------------- /Tiny Tool Use/README.md: -------------------------------------------------------------------------------- 1 | # Tiny Tool Use 2 | 3 | An intentionally-tiny yet production-ready open-source library for fine-tuning Large Language Models (LLMs) to make robust, auditable tool calls. 4 | 5 | ## Features 6 | 7 | - **Configuration-only workflows** – every experiment, tool schema, and hyper-parameter lives in a JSON file so results travel cleanly between repos. 8 | - **Interchangeable optimisers** – swap Supervised Fine-Tuning, Direct Preference Optimisation (DPO), or synthetic teacher signals with a single config flag. 9 | - **First-class evaluation support** – TensorBoard dashboards and ready-made Berkeley Function Calling Leaderboard scripts. 10 | - **Dataset flexibility** – plug in real data, generate synthetic traces, or compose both without touching core code. 11 | 12 | ## Quick Start 13 | 14 | 1. **Setup Environment**: 15 | ```bash 16 | chmod +x setup.sh 17 | ./setup.sh 18 | source venv/bin/activate 19 | ``` 20 | 21 | 2. **Run Basic Training**: 22 | ```bash 23 | # Supervised fine-tuning with manual templates 24 | python train.py --config configs/sft_toolbench_config.json --outdir outputs/toolbench_results 25 | 26 | # DPO training with manual templates 27 | python train.py --config configs/dpo_config.json --outdir outputs/dpo_results 28 | ``` 29 | 30 | 3. 
**Merge LoRA Adapters**
31 | 
32 | If you fine-tune the model with LoRA adapters, you can merge the adapters into the base model once training finishes, using the following command:
33 | 
34 | ```bash
35 | python save_merge_model.py --base_model BASE-MODEL-NAME --adapter_path PATH/TO/SAVED/ADAPTER --output_dir PATH/TO/MERGED/MODEL
36 | ```
37 | 
38 | 
39 | 4. **Evaluate Trained Model**:
40 | ```bash
41 | python evaluate.py --model-path PATH/TO/FINAL/SAVED/MODEL --config PATH/TO/TRAINING/CONFIGURATION/JSON/FILE
42 | ```
43 | 
44 | 5. **Interactive Examples**:
45 | ```bash
46 | python examples/run_examples.py
47 | ```
48 | 
49 | ## Project Structure
50 | 
51 | ```
52 | ├── configs/                        # Configuration files
53 | │   ├── sft_toolbench_config.json   # SFT example with real ToolBench data
54 | │   └── dpo_config.json             # DPO example
55 | ├── src/                            # Source code
56 | │   ├── data/                       # Data generation and loading
57 | │   ├── models/                     # Model definitions
58 | │   ├── training/                   # Training loops and algorithms
59 | │   ├── tools/                      # Tool handling
60 | │   └── utils/                      # Utilities and evaluation
61 | ├── examples/                       # Example configurations and tools
62 | │   ├── run_examples.py             # Interactive examples
63 | │   └── test_cases.json             # Standard test cases
64 | ├── tests/                          # Unit tests
65 | ├── train.py                        # Main training script
66 | ├── evaluate.py                     # Evaluation script
67 | └── setup.sh                        # Setup script
68 | ```
69 | 
70 | ## Configuration
71 | 
72 | The library uses JSON configuration files to define:
73 | 
74 | ### Model Configuration
75 | ```json
76 | {
77 |   "model": {
78 |     "name": "Qwen/Qwen3-0.6B",
79 |     "trust_remote_code": false,
80 |     "torch_dtype": "float16",
81 |     "device_map": "auto"
82 |   }
83 | }
84 | ```
85 | 
86 | ### Training Configuration
87 | ```json
88 | {
89 |   "training": {
90 |     "method": "sft",            // "sft", "dpo", "teacher_mode"
91 |     "num_epochs": 3,
92 |     "learning_rate": 5e-5,
93 |     "batch_size": 4,
94 |     "use_lora": true
95 |   }
96 | }
97 | ```
98 | 
99 | ### Tool Definitions
100 | If you want to train your model on custom tools with synthetic data generated for those tools, you can
101 | define the tools as well as the dataset.
102 | ```json
103 | {
104 |   "tools": [
105 |     {
106 |       "name": "calculator",
107 |       "description": "Perform mathematical calculations",
108 |       "type": "function",
109 |       "parameters": {
110 |         "type": "object",
111 |         "properties": {
112 |           "expression": {
113 |             "type": "string",
114 |             "description": "Mathematical expression to evaluate"
115 |           }
116 |         },
117 |         "required": ["expression"]
118 |       }
119 |     }
120 |   ]
121 | }
122 | ```
123 | 
124 | ## Supported Training Methods
125 | 
126 | 1. **Supervised Fine-tuning (SFT)**: Standard next-token prediction on tool-augmented conversations
127 | 2. **DPO**: Direct Preference Optimization using preference pairs
128 | 3. **Teacher Mode**: Self-supervised data generation (Toolformer-style) with synthetic data
129 | 
130 | ## Data Generation Strategies
131 | 
132 | 1. **ToolBench**: Use both real and synthetic ToolBench datasets
133 | 2. **Teacher Mode**: LLM generates its own tool-augmented examples
134 | 3. **Manual Templates**: Bootstrap from canonical examples with paraphrasing
135 | 
136 | ## Built-in Tools
137 | 
138 | The library includes several built-in tools for testing:
139 | 
140 | - **Calculator**: Basic arithmetic operations
141 | - **Weather**: Mock weather API
142 | - **Search**: Mock search functionality
143 | 
144 | You can easily add custom tools by extending the `ToolExecutor` class.
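As a minimal sketch (not part of the library), extending the executor with a hypothetical `string_length` tool via a `MyToolExecutor` subclass might look like this, assuming the script runs from the project root with `src` on the path, as the bundled tests and demo do:

```python
import sys
sys.path.insert(0, "src")  # same path convention as tests/test_basic.py and demo.py

from tools.executor import ToolExecutor


class MyToolExecutor(ToolExecutor):
    """ToolExecutor with one extra built-in function (illustrative only)."""

    def _execute_function_tool(self, tool, parameters):
        # Route the hypothetical custom function; fall back to the built-ins otherwise.
        if tool.get("function", tool["name"]) == "string_length":
            text = parameters.get("text", "")
            return {"length": len(text)}
        return super()._execute_function_tool(tool, parameters)


# Schema for the hypothetical tool, in the same shape as the built-in tool configs.
tools_config = [
    {
        "name": "string_length",
        "description": "Count the characters in a string",
        "type": "function",
        "function": "string_length",
        "parameters": {
            "type": "object",
            "properties": {"text": {"type": "string", "description": "Text to measure"}},
            "required": ["text"],
        },
    }
]

executor = MyToolExecutor(tools_config)
print(executor.execute_tool("string_length", {"text": "tiny tool use"}))  # -> {'length': 13}
```

The matching schema entry can then be added to the `tools` array of a training config so the rest of the pipeline sees the new tool.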
145 | 
146 | ## Training Examples
147 | 
148 | ### Example 1: SFT on ToolBench data
149 | ```bash
150 | python train.py --config configs/sft_toolbench_config.json --output-dir outputs/sft_toolbench
151 | ```
152 | 
153 | 
154 | ## Evaluation
155 | 
156 | The library supports two kinds of evaluation:
157 | 1. Berkeley Function Calling Leaderboard (BFCL) evaluation
158 | 
159 | 2. Comprehensive built-in metrics:
160 | 
161 | - **Tool Accuracy**: Correct tool selection
162 | - **Format Correctness**: Proper tool call formatting
163 | - **Execution Success**: Successful tool execution
164 | - **Response Quality**: Overall response quality
165 | 
166 | 
167 | 1. To evaluate a trained model on the Berkeley Function Calling Leaderboard (BFCL), first set the project root:
168 | 
169 | ```bash
170 | # In your shell environment
171 | export BFCL_PROJECT_ROOT=/path/to/your/desired/project/directory
172 | ```
173 | 
174 | 
175 | Then generate responses on BFCL with the fine-tuned `Qwen3-0.6B` model:
176 | ```bash
177 | bfcl generate --model Qwen/Qwen3-0.6B-FC --local-model-path PATH_TO_FINETUNED_MODEL --test-category simple,parallel,multiple,multi_turn
178 | ```
179 | This creates a `result/` directory containing the generated JSON files.
180 | Once the model responses are generated, run the following command to score the trained model:
181 | 
182 | ```bash
183 | bfcl evaluate --model Qwen/Qwen3-0.6B-FC --test-category simple,parallel,multiple,multi_turn
184 | ```
185 | 
186 | 2. Run the comprehensive evaluation with `Qwen3-0.6B` fine-tuned using DPO:
187 | 
188 | ```bash
189 | python evaluate.py --model-path PATH/TO/FINETUNED/MODEL --config PATH/TO/CONFIG/FILE
190 | ```
191 | 
192 | For example, if training was performed with `dpo_config.json`:
193 | 
194 | ```bash
195 | python evaluate.py --model-path PATH/TO/FINETUNED/MODEL --config dpo_config.json
196 | ```
197 | 
198 | 
199 | ## Customization
200 | 
201 | ### Adding New Tools
202 | 
203 | 1. Define tool schema in configuration:
204 | ```json
205 | {
206 |   "name": "my_tool",
207 |   "description": "My custom tool",
208 |   "type": "function",
209 |   "parameters": {...}
210 | }
211 | ```
212 | 
213 | 2. Implement tool function in `src/tools/executor.py`:
214 | ```python
215 | def _my_tool(self, parameters):
216 |     # Implementation here
217 |     return {"result": "success"}
218 | ```
219 | 
220 | ### Custom Training Methods
221 | 
222 | Extend the `ToolTrainer` class in `src/training/trainer.py` to add new training algorithms.
223 | 
224 | ### Custom Data Sources
225 | 
226 | Implement new data generation strategies in `src/data/data_generator.py`.
227 | 
228 | ## Potential Applications
229 | 
230 | This library can serve as a foundation for a variety of tool-use research problems and production scenarios, including:
231 | 
232 | - **Robotics control** – grounding language instructions into low-level robot actions through tool calls.
233 | - **Autonomous agents** – building multi-step assistants that plan, call, and combine external APIs.
234 | - **Workflow automation** – integrating structured tool calls into data-engineering or MLOps pipelines.
235 | - **Information retrieval** – augmenting LLM responses with live search or specialized knowledge bases.
236 | - **Education & tutoring systems** – teaching models to execute calculators, solvers, or simulators on demand.
237 | 
238 | ## Contributing
239 | 
240 | Contributions are welcome! Please see the issues page for areas where help is needed.
241 | 242 | ## License 243 | 244 | MIT License - see LICENSE file for details. 245 | -------------------------------------------------------------------------------- /Tiny Tool Use/demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Demo script showcasing all features of the LLM Tool Training Playground. 4 | """ 5 | 6 | import json 7 | import subprocess 8 | import sys 9 | from pathlib import Path 10 | 11 | 12 | def print_header(title: str): 13 | """Print a formatted header.""" 14 | print(f"\n{'='*60}") 15 | print(f"🎯 {title}") 16 | print(f"{'='*60}") 17 | 18 | 19 | def print_step(step_num: int, description: str): 20 | """Print a formatted step.""" 21 | print(f"\n📌 Step {step_num}: {description}") 22 | print("-" * 40) 23 | 24 | 25 | def demonstrate_tool_execution(): 26 | """Demonstrate basic tool execution.""" 27 | print_step(1, "Tool Execution Demo") 28 | 29 | # Add src to Python path for demo 30 | demo_script = ''' 31 | import sys 32 | sys.path.insert(0, "src") 33 | 34 | from tools.executor import ToolExecutor 35 | 36 | # Define tools 37 | tools_config = [ 38 | { 39 | "name": "calculator", 40 | "description": "Perform calculations", 41 | "type": "function", 42 | "function": "calculator", 43 | "parameters": {"type": "object", "properties": {"expression": {"type": "string"}}} 44 | }, 45 | { 46 | "name": "weather", 47 | "description": "Get weather info", 48 | "type": "function", 49 | "function": "weather", 50 | "parameters": {"type": "object", "properties": {"location": {"type": "string"}}} 51 | } 52 | ] 53 | 54 | # Initialize executor 55 | executor = ToolExecutor(tools_config) 56 | 57 | # Test calculations 58 | print("🧮 Calculator Tests:") 59 | calc_tests = ["2 + 3", "10 * 7", "(15 + 5) / 4", "2 ** 8"] 60 | for expr in calc_tests: 61 | result = executor.execute_tool("calculator", {"expression": expr}) 62 | print(f" {expr} = {result.get('result', 'Error')}") 63 | 64 | print("\\n🌤️ Weather Tests:") 65 | weather_tests = ["London", "Tokyo", "New York"] 66 | for location in weather_tests: 67 | result = executor.execute_tool("weather", {"location": location}) 68 | print(f" {location}: {result.get('temperature', 'N/A')} - {result.get('condition', 'N/A')}") 69 | ''' 70 | 71 | with open("demo_tools.py", "w") as f: 72 | f.write(demo_script) 73 | 74 | try: 75 | result = subprocess.run([sys.executable, "demo_tools.py"], 76 | capture_output=True, text=True, cwd=Path.cwd()) 77 | print(result.stdout) 78 | if result.stderr: 79 | print(f"Warnings: {result.stderr}") 80 | except Exception as e: 81 | print(f"Demo failed: {e}") 82 | finally: 83 | Path("demo_tools.py").unlink(missing_ok=True) 84 | 85 | 86 | def demonstrate_data_generation(): 87 | """Demonstrate data generation strategies.""" 88 | print_step(2, "Data Generation Demo") 89 | 90 | strategies = [ 91 | ("manual_templates", "Manual templates with paraphrasing"), 92 | ("teacher_mode", "Teacher mode (Toolformer-style)"), 93 | ("toolbench", "ToolBench-style synthetic data") 94 | ] 95 | 96 | for strategy, description in strategies: 97 | print(f"\n🔧 {description}:") 98 | 99 | demo_script = f''' 100 | import sys 101 | sys.path.insert(0, "src") 102 | 103 | from data.data_generator import DataGenerator 104 | 105 | data_config = {{ 106 | "strategy": "{strategy}", 107 | "max_samples": 5, 108 | "train_split": 0.8 109 | }} 110 | 111 | tools_config = [{{ 112 | "name": "calculator", 113 | "description": "Perform calculations", 114 | "type": "function", 115 | "function": "calculator", 116 | 
"parameters": {{"type": "object", "properties": {{"expression": {{"type": "string"}}}}}} 117 | }}] 118 | 119 | generator = DataGenerator(data_config, tools_config) 120 | train_dataset, eval_dataset = generator.prepare_datasets() 121 | 122 | print(f"Generated {{len(train_dataset)}} training samples") 123 | print(f"Generated {{len(eval_dataset)}} evaluation samples") 124 | 125 | if len(train_dataset) > 0: 126 | print("\\nSample training data:") 127 | sample = train_dataset[0]["text"][:200] + "..." if len(train_dataset[0]["text"]) > 200 else train_dataset[0]["text"] 128 | print(f" {{sample}}") 129 | ''' 130 | 131 | with open("demo_data.py", "w") as f: 132 | f.write(demo_script) 133 | 134 | try: 135 | result = subprocess.run([sys.executable, "demo_data.py"], 136 | capture_output=True, text=True, cwd=Path.cwd()) 137 | print(result.stdout) 138 | except Exception as e: 139 | print(f" Error: {e}") 140 | finally: 141 | Path("demo_data.py").unlink(missing_ok=True) 142 | 143 | 144 | def show_training_configurations(): 145 | """Show available training configurations.""" 146 | print_step(3, "Training Configurations") 147 | 148 | config_files = [ 149 | ("calculator_config.json", "Supervised Fine-tuning with Calculator"), 150 | ("ppo_config.json", "PPO Training with Rewards"), 151 | ("teacher_mode_config.json", "Teacher Mode Multi-tool Training") 152 | ] 153 | 154 | for config_file, description in config_files: 155 | config_path = f"configs/{config_file}" 156 | if Path(config_path).exists(): 157 | print(f"\n📋 {description}") 158 | print(f" Config: {config_file}") 159 | 160 | with open(config_path, 'r') as f: 161 | config = json.load(f) 162 | 163 | print(f" Method: {config['training']['method']}") 164 | print(f" Model: {config['model']['name']}") 165 | print(f" Data Strategy: {config['data']['strategy']}") 166 | print(f" Tools: {[tool['name'] for tool in config['tools']]}") 167 | 168 | print(f"\n 🚀 To train: python train.py --config {config_path}") 169 | print(f" 🔍 To evaluate: python evaluate.py --model-path outputs/{config_file.replace('.json', '')} --config {config_path}") 170 | 171 | 172 | def show_evaluation_example(): 173 | """Show evaluation capabilities.""" 174 | print_step(4, "Evaluation Framework") 175 | 176 | print("📊 Available Evaluation Metrics:") 177 | metrics = [ 178 | ("Tool Accuracy", "Correct tool selection rate"), 179 | ("Format Correctness", "Proper tool call formatting"), 180 | ("Execution Success", "Successful tool execution rate"), 181 | ("Response Quality", "Overall response quality score") 182 | ] 183 | 184 | for metric, description in metrics: 185 | print(f" • {metric}: {description}") 186 | 187 | print("\n📝 Standard Test Cases Available:") 188 | try: 189 | with open("examples/test_cases.json", 'r') as f: 190 | test_cases = json.load(f) 191 | 192 | print(f" • {len(test_cases)} predefined test cases") 193 | print(" • Covers: calculator, weather, search tools") 194 | 195 | print("\n🔍 Sample Test Case:") 196 | sample_case = test_cases[0] 197 | print(f" Prompt: {sample_case['prompt']}") 198 | print(f" Expected Tool: {sample_case['expected_tool']}") 199 | print(f" Expected Params: {sample_case['expected_params']}") 200 | 201 | except Exception as e: 202 | print(f" Error loading test cases: {e}") 203 | 204 | 205 | def show_next_steps(): 206 | """Show next steps for users.""" 207 | print_step(5, "Getting Started") 208 | 209 | steps = [ 210 | "🔧 Setup Environment", 211 | " ./setup.sh", 212 | "", 213 | "🎯 Run Quick Training Example", 214 | " python train.py --config 
configs/calculator_config.json --output-dir outputs/demo", 215 | "", 216 | "🔍 Evaluate Results", 217 | " python evaluate.py --model-path outputs/demo --config configs/calculator_config.json", 218 | "", 219 | "🎮 Try Interactive Examples", 220 | " python examples/run_examples.py", 221 | "", 222 | "🧪 Run Tests", 223 | " python validate.py", 224 | "", 225 | "📚 Customize Your Tools", 226 | " Edit configs/your_config.json", 227 | " Add tools in src/tools/executor.py" 228 | ] 229 | 230 | for step in steps: 231 | print(step) 232 | 233 | 234 | def main(): 235 | """Main demo function.""" 236 | print_header("LLM Tool Use Training Playground - Complete Demo") 237 | 238 | print("🎉 Welcome to the LLM Tool Use Training Playground!") 239 | print("This demo showcases all the key features of the framework.") 240 | 241 | # Run demonstrations 242 | demonstrate_tool_execution() 243 | demonstrate_data_generation() 244 | show_training_configurations() 245 | show_evaluation_example() 246 | show_next_steps() 247 | 248 | print_header("Demo Complete") 249 | print("✅ The playground is ready for training LLMs on tool use!") 250 | print("📖 See README.md for detailed documentation.") 251 | print("🐛 Found issues? Check validate.py or create an issue.") 252 | 253 | 254 | if __name__ == "__main__": 255 | main() 256 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluation utilities for tool use models.""" 2 | 3 | import json 4 | import logging 5 | import re 6 | from typing import Dict, Any, List 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from datasets import Dataset 10 | 11 | from ..tools.executor import ToolExecutor 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ToolUseEvaluator: 18 | """Evaluates model performance on tool use tasks.""" 19 | 20 | def __init__(self, model_path: str, tools_config: List[Dict[str, Any]]): 21 | self.model_path = model_path 22 | self.tools_config = tools_config 23 | self.tool_executor = ToolExecutor(tools_config) 24 | 25 | # Load model and tokenizer 26 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 27 | self.model = AutoModelForCausalLM.from_pretrained( 28 | model_path, 29 | torch_dtype=torch.float16, 30 | device_map="auto" 31 | ) 32 | 33 | if self.tokenizer.pad_token is None: 34 | self.tokenizer.pad_token = self.tokenizer.eos_token 35 | 36 | def evaluate_dataset(self, eval_dataset: Dataset) -> Dict[str, float]: 37 | """Evaluate model on a dataset.""" 38 | results = { 39 | "tool_accuracy": 0.0, 40 | "format_correctness": 0.0, 41 | "execution_success": 0.0, 42 | "response_quality": 0.0 43 | } 44 | 45 | total_samples = len(eval_dataset) 46 | correct_tools = 0 47 | correct_format = 0 48 | successful_executions = 0 49 | 50 | for example in eval_dataset: 51 | # Generate response 52 | response = self._generate_response(example["text"]) 53 | 54 | # Extract ground truth tool 55 | gt_tool = example.get("tool_name", "") 56 | 57 | # Evaluate tool accuracy 58 | predicted_tool = self._extract_tool_name(response) 59 | if predicted_tool == gt_tool: 60 | correct_tools += 1 61 | 62 | # Evaluate format correctness 63 | if self._check_format_correctness(response): 64 | correct_format += 1 65 | 66 | # Evaluate execution success 67 | if self._check_execution_success(response): 68 | successful_executions += 1 69 | 70 | results["tool_accuracy"] = correct_tools / 
total_samples 71 | results["format_correctness"] = correct_format / total_samples 72 | results["execution_success"] = successful_executions / total_samples 73 | results["response_quality"] = (results["tool_accuracy"] + results["format_correctness"]) / 2 74 | 75 | return results 76 | 77 | def _generate_response(self, prompt: str) -> str: 78 | """Generate response for a given prompt.""" 79 | # Extract just the human part for generation 80 | if "Human:" in prompt and "Assistant:" in prompt: 81 | human_part = prompt.split("Assistant:")[0] + "Assistant:" 82 | else: 83 | human_part = prompt 84 | 85 | # Encode input and move to the same device as the model 86 | inputs = self.tokenizer.encode(human_part, return_tensors="pt") 87 | device = self.model.device # Get the model's device 88 | inputs = inputs.to(device) # Move input tensor to the same device 89 | 90 | with torch.no_grad(): 91 | outputs = self.model.generate( 92 | inputs, 93 | max_length=inputs.shape[1] + 256, 94 | num_return_sequences=1, 95 | temperature=0.7, 96 | do_sample=True, 97 | pad_token_id=self.tokenizer.pad_token_id, 98 | eos_token_id=self.tokenizer.eos_token_id 99 | ) 100 | 101 | 102 | response = self.tokenizer.decode(outputs[0], skip_special_tokens=False) 103 | 104 | # Extract just the generated part 105 | if human_part in response: 106 | response = response.replace(human_part, "").strip() 107 | 108 | return response 109 | 110 | def _extract_tool_name(self, response: str) -> str: 111 | """Extract tool name from response.""" 112 | # Look for tool call pattern 113 | tool_call_pattern = r'\[TOOL_CALL\](.*?)\[/TOOL_CALL\]' 114 | matches = re.findall(tool_call_pattern, response, re.DOTALL) 115 | 116 | if matches: 117 | try: 118 | tool_data = json.loads(matches[0]) 119 | return tool_data.get("name", "") 120 | except json.JSONDecodeError: 121 | pass 122 | 123 | # Fallback: look for tool names in text 124 | for tool in self.tools_config: 125 | if tool["name"].lower() in response.lower(): 126 | return tool["name"] 127 | 128 | return "" 129 | 130 | def _check_format_correctness(self, response: str) -> bool: 131 | """Check if response has correct tool call format.""" 132 | # Check for proper tool call tags 133 | if "[TOOL_CALL]" not in response or "[/TOOL_CALL]" not in response: 134 | return False 135 | 136 | # Extract tool call content 137 | tool_call_pattern = r'\[TOOL_CALL\](.*?)\[/TOOL_CALL\]' 138 | matches = re.findall(tool_call_pattern, response, re.DOTALL) 139 | 140 | if not matches: 141 | return False 142 | 143 | try: 144 | # Try to parse as JSON 145 | tool_data = json.loads(matches[0]) 146 | 147 | # Check required fields 148 | if "name" not in tool_data or "parameters" not in tool_data: 149 | return False 150 | 151 | # Check if tool name is valid 152 | tool_names = [tool["name"] for tool in self.tools_config] 153 | if tool_data["name"] not in tool_names: 154 | return False 155 | 156 | return True 157 | 158 | except json.JSONDecodeError: 159 | return False 160 | 161 | def _check_execution_success(self, response: str) -> bool: 162 | """Check if tool execution would be successful.""" 163 | if not self._check_format_correctness(response): 164 | return False 165 | 166 | # Extract and execute tool call 167 | tool_call_pattern = r'\[TOOL_CALL\](.*?)\[/TOOL_CALL\]' 168 | matches = re.findall(tool_call_pattern, response, re.DOTALL) 169 | 170 | if matches: 171 | try: 172 | tool_data = json.loads(matches[0]) 173 | tool_name = tool_data["name"] 174 | parameters = tool_data["parameters"] 175 | 176 | # Execute tool 177 | result = 
self.tool_executor.execute_tool(tool_name, parameters) 178 | 179 | # Check if execution was successful (no error) 180 | return "error" not in result 181 | 182 | except Exception: 183 | return False 184 | 185 | return False 186 | 187 | def evaluate_specific_cases(self, test_cases: List[Dict[str, Any]]) -> Dict[str, Any]: 188 | """Evaluate on specific test cases.""" 189 | results = [] 190 | 191 | for case in test_cases: 192 | prompt = case["prompt"] 193 | expected_tool = case["expected_tool"] 194 | expected_params = case.get("expected_params", {}) 195 | 196 | response = self._generate_response(prompt) 197 | predicted_tool = self._extract_tool_name(response) 198 | 199 | result = { 200 | "prompt": prompt, 201 | "response": response, 202 | "expected_tool": expected_tool, 203 | "predicted_tool": predicted_tool, 204 | "tool_correct": predicted_tool == expected_tool, 205 | "format_correct": self._check_format_correctness(response), 206 | "execution_success": self._check_execution_success(response) 207 | } 208 | 209 | results.append(result) 210 | 211 | # Calculate summary statistics 212 | summary = { 213 | "total_cases": len(results), 214 | "tool_accuracy": sum(r["tool_correct"] for r in results) / len(results), 215 | "format_accuracy": sum(r["format_correct"] for r in results) / len(results), 216 | "execution_accuracy": sum(r["execution_success"] for r in results) / len(results), 217 | "details": results 218 | } 219 | 220 | return summary 221 | 222 | 223 | def create_test_cases() -> List[Dict[str, Any]]: 224 | """Create standard test cases for evaluation.""" 225 | return [ 226 | { 227 | "prompt": "Human: What is 25 * 4?", 228 | "expected_tool": "calculator", 229 | "expected_params": {"expression": "25 * 4"} 230 | }, 231 | { 232 | "prompt": "Human: What's the weather like in Paris?", 233 | "expected_tool": "weather", 234 | "expected_params": {"location": "Paris"} 235 | }, 236 | { 237 | "prompt": "Human: Search for information about machine learning", 238 | "expected_tool": "search", 239 | "expected_params": {"query": "machine learning"} 240 | }, 241 | { 242 | "prompt": "Human: Calculate 100 divided by 5", 243 | "expected_tool": "calculator", 244 | "expected_params": {"expression": "100 / 5"} 245 | }, 246 | { 247 | "prompt": "Human: How's the weather in Tokyo today?", 248 | "expected_tool": "weather", 249 | "expected_params": {"location": "Tokyo"} 250 | } 251 | ] 252 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/training/trainer.py: -------------------------------------------------------------------------------- 1 | """Training module for tool use models.""" 2 | # PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True 3 | import logging 4 | from pathlib import Path 5 | from typing import Dict, Any, Optional 6 | 7 | 8 | import torch 9 | from transformers import ( 10 | AutoTokenizer, 11 | AutoModelForCausalLM 12 | ) 13 | from datasets import Dataset 14 | from peft import LoraConfig, get_peft_model, TaskType 15 | from trl import DPOTrainer, SFTTrainer 16 | from trl import DPOConfig, SFTConfig 17 | from torch.utils.tensorboard import SummaryWriter 18 | from transformers import BitsAndBytesConfig 19 | from peft import prepare_model_for_kbit_training 20 | 21 | from ..tools.executor import ToolExecutor 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | class ToolTrainer: 28 | """Main trainer class for tool use models.""" 29 | 30 | def __init__( 31 | self, 32 | config: Dict[str, Any], 33 | train_dataset: Dataset, 34 | eval_dataset: Dataset, 
35 | output_dir: Path 36 | ): 37 | self.config = config 38 | self.train_dataset = train_dataset 39 | self.eval_dataset = eval_dataset 40 | self.output_dir = output_dir 41 | 42 | # Initialize TensorBoard logging 43 | self.writer = None 44 | if self.config.get("tensorboard", {}).get("enabled", False): 45 | log_dir = self.config.get("tensorboard", {}).get("log_dir", str(output_dir / "runs")) 46 | self.writer = SummaryWriter(log_dir=log_dir) 47 | 48 | # Initialize model and tokenizer 49 | self.tokenizer = self._load_tokenizer() 50 | self.model = self._load_model() 51 | 52 | 53 | 54 | # Initialize tool executor for evaluation 55 | self.tool_executor = ToolExecutor(config["tools"]) 56 | 57 | # Training method 58 | self.training_method = config["training"]["method"] 59 | 60 | logger.info(f"Initialized trainer with method: {self.training_method}") 61 | 62 | def _load_tokenizer(self) -> AutoTokenizer: 63 | """Load tokenizer.""" 64 | 65 | model_name = self.config["model"]["name"] 66 | 67 | if "qwen3" in model_name.lower() and "toolbench" in self.config["data"]["strategy"]: 68 | 69 | 70 | tokenizer = AutoTokenizer.from_pretrained( 71 | model_name, 72 | trust_remote_code=self.config["model"].get("trust_remote_code", False), 73 | 74 | ) 75 | 76 | return tokenizer 77 | 78 | 79 | else: 80 | 81 | 82 | model_name = self.config["model"]["name"] 83 | 84 | tokenizer = AutoTokenizer.from_pretrained( 85 | model_name, 86 | trust_remote_code=self.config["model"].get("trust_remote_code", False), 87 | 88 | ) 89 | 90 | 91 | 92 | # Add special tokens for tool calls 93 | special_tokens = { 94 | "additional_special_tokens": [ 95 | "[TOOL_CALL]", "[/TOOL_CALL]", 96 | "[RESULT]", "[/RESULT]" 97 | ] 98 | } 99 | 100 | tokenizer.add_special_tokens(special_tokens) 101 | 102 | #Set pad token if not exists 103 | if tokenizer.pad_token is None: 104 | tokenizer.pad_token = tokenizer.eos_token 105 | 106 | return tokenizer 107 | 108 | def _load_model(self) -> AutoModelForCausalLM: 109 | """Load and prepare model.""" 110 | model_config = self.config["model"] 111 | 112 | 113 | if self.config["training"].get("use_lora",True): 114 | 115 | #bits and bytes configuration 116 | bnb_config = BitsAndBytesConfig( 117 | load_in_4bit=True, 118 | bnb_4bit_quant_type="nf4", 119 | bnb_4bit_use_double_quant=True, 120 | bnb_4bit_compute_dtype="bfloat16" 121 | ) 122 | 123 | # Load base model 124 | model = AutoModelForCausalLM.from_pretrained( 125 | model_config["name"], 126 | trust_remote_code=model_config.get("trust_remote_code", False), 127 | torch_dtype=getattr(torch, model_config.get("torch_dtype", "float16")), 128 | device_map=model_config.get("device_map", "auto"), 129 | quantization_config=bnb_config 130 | 131 | ) 132 | 133 | 134 | model = prepare_model_for_kbit_training(model) 135 | # Resize embeddings for new tokens 136 | #model.resize_token_embeddings(len(self.tokenizer)) 137 | 138 | lora_config = LoraConfig( 139 | task_type=TaskType.CAUSAL_LM, 140 | inference_mode=False, 141 | r=self.config["training"].get("lora_r", 16), 142 | lora_alpha=self.config["training"].get("lora_alpha", 32), 143 | lora_dropout=self.config["training"].get("lora_dropout", 0.1), 144 | target_modules=["q_proj", "v_proj", "k_proj", "o_proj"] 145 | ) 146 | model = get_peft_model(model, lora_config) 147 | model.print_trainable_parameters() 148 | 149 | else: 150 | 151 | # Load base model 152 | model = AutoModelForCausalLM.from_pretrained( 153 | model_config["name"], 154 | trust_remote_code=model_config.get("trust_remote_code", False), 155 | torch_dtype=getattr(torch, 
model_config.get("torch_dtype", "float16")), 156 | device_map=model_config.get("device_map", "auto"), 157 | ) 158 | 159 | # Resize embeddings for new tokens 160 | model.resize_token_embeddings(len(self.tokenizer)) 161 | 162 | return model 163 | 164 | def train(self, resume_from_checkpoint: Optional[str] = None): 165 | """Train the model based on the specified method.""" 166 | if self.training_method == "sft": 167 | self._train_sft(resume_from_checkpoint) 168 | elif self.training_method == "dpo": 169 | self._train_dpo(resume_from_checkpoint) 170 | elif self.training_method == "teacher_mode": 171 | self._train_teacher_mode(resume_from_checkpoint) 172 | else: 173 | raise ValueError(f"Unknown training method: {self.training_method}. Supported methods: sft, dpo, teacher_mode") 174 | 175 | def _train_sft(self, resume_from_checkpoint: Optional[str] = None): 176 | """Supervised fine-tuning.""" 177 | logger.info("Starting supervised fine-tuning...") 178 | 179 | 180 | # Training arguments 181 | training_args = SFTConfig( 182 | output_dir=str(self.output_dir), 183 | overwrite_output_dir=True, 184 | num_train_epochs=self.config["training"].get("num_epochs", 3), 185 | per_device_train_batch_size=self.config["training"].get("batch_size", 4), 186 | per_device_eval_batch_size=self.config["training"].get("eval_batch_size", 4), 187 | gradient_accumulation_steps=self.config["training"].get("gradient_accumulation_steps", 1), 188 | learning_rate=self.config["training"].get("learning_rate", 5e-5), 189 | warmup_steps=self.config["training"].get("warmup_steps", 100), 190 | logging_steps=10, 191 | eval_strategy="steps", 192 | eval_steps=100, 193 | save_strategy="steps", 194 | save_steps=500, 195 | save_total_limit=3, 196 | load_best_model_at_end=True, 197 | metric_for_best_model="eval_loss", 198 | greater_is_better=False, 199 | report_to="tensorboard" if self.config.get("tensorboard", {}).get("enabled") else None, 200 | dataloader_pin_memory=False, 201 | fp16=self.config["training"].get("use_lora",True), #turn it to true if using gpu 202 | max_grad_norm=1.0, 203 | optim = "adamw_torch" , 204 | max_seq_length=self.config["training"].get("max_length",2048), 205 | label_names = ["labels"] 206 | ) 207 | 208 | 209 | 210 | 211 | 212 | trainer = SFTTrainer( 213 | model = self.model, 214 | train_dataset = self.train_dataset, 215 | eval_dataset = self.eval_dataset, 216 | args = training_args, 217 | processing_class = self.tokenizer, 218 | 219 | 220 | ) 221 | 222 | 223 | 224 | 225 | # Train 226 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 227 | 228 | # Save final model 229 | trainer.save_model() 230 | self.tokenizer.save_pretrained(self.output_dir) 231 | 232 | def _train_dpo(self, resume_from_checkpoint: Optional[str] = None): 233 | """Train the model using DPO.""" 234 | logger.info("Starting DPO training") 235 | 236 | # Make sure we're using LoRA or the training will likely fail 237 | use_lora = self.config["training"].get("use_lora", True) 238 | 239 | preference_dataset = self._create_preference_dataset() 240 | 241 | if not use_lora: 242 | logger.warning( 243 | "⚠️ You're attempting to run DPO without LoRA which may cause NaN values. " 244 | "Consider enabling LoRA with 'use_lora': true in your config." 
245 | ) 246 | 247 | # Setup training arguments with gradient clipping 248 | training_args = DPOConfig( 249 | output_dir=str(self.output_dir), 250 | num_train_epochs=self.config["training"].get("num_epochs", 3), 251 | per_device_train_batch_size=self.config["training"].get("batch_size", 4), 252 | gradient_accumulation_steps=self.config["training"].get("gradient_accumulation_steps", 1), 253 | learning_rate=self.config["training"].get("learning_rate", 5e-6), 254 | max_grad_norm=self.config["training"].get("max_grad_norm", 0.3), # Add strict gradient clipping 255 | logging_steps=10, 256 | save_strategy="steps", 257 | save_steps=100, 258 | save_total_limit=3, 259 | optim=self.config["training"].get("optim", "paged_adamw_8bit"), # Use 8-bit optimizer 260 | bf16=self.config["training"].get("bf16", False), 261 | fp16=self.config["training"].get("fp16", True), # Use mixed precision 262 | max_length=self.config["training"].get("max_length", 512), 263 | remove_unused_columns=False, 264 | beta=0.1, # Lower beta to stabilize training 265 | report_to="tensorboard" if self.config.get("tensorboard", {}).get("enabled") else None, 266 | ) 267 | 268 | 269 | # Create DPO trainer with improved stability 270 | trainer = DPOTrainer( 271 | model=self.model, 272 | ref_model=None, # Use same model as reference 273 | args=training_args, 274 | train_dataset=preference_dataset, 275 | processing_class=self.tokenizer, 276 | ) 277 | 278 | # Add gradient checkpointing for memory efficiency 279 | if hasattr(self.model, "gradient_checkpointing_enable"): 280 | self.model.gradient_checkpointing_enable() 281 | 282 | # Train 283 | logger.info("Starting DPO training") 284 | trainer.train(resume_from_checkpoint=resume_from_checkpoint) 285 | 286 | # Save the trained model 287 | trainer.save_model(self.output_dir / "dpo_model") 288 | logger.info(f"Model saved to {self.output_dir / 'dpo_model'}") 289 | 290 | def _train_teacher_mode(self, resume_from_checkpoint: Optional[str] = None): 291 | """Teacher mode training (Toolformer-style).""" 292 | logger.info("Starting teacher mode training...") 293 | 294 | # This combines SFT with self-supervised learning 295 | # The data generation already handles teacher mode data creation 296 | self._train_sft(resume_from_checkpoint) 297 | 298 | def _tokenize_dataset(self, dataset: Dataset) -> Dataset: 299 | """Tokenize a dataset.""" 300 | def tokenize_function(examples): 301 | # Tokenize the text 302 | 303 | tokenized = self.tokenizer( 304 | examples["text"], 305 | truncation=True, 306 | padding=True, 307 | max_length=self.config["training"].get("max_length", 512), 308 | return_tensors="pt" 309 | ) 310 | 311 | # For causal LM, labels are the same as input_ids 312 | tokenized["labels"] = tokenized["input_ids"].clone() 313 | 314 | return tokenized 315 | 316 | return dataset.map( 317 | tokenize_function, 318 | batched=True, 319 | remove_columns=dataset.column_names 320 | ) 321 | 322 | def _create_preference_dataset(self) -> Dataset: 323 | """Create preference dataset for DPO.""" 324 | # This is a simplified implementation 325 | # In practice, you'd want human preferences or model-based ranking 326 | 327 | preference_data = [] 328 | 329 | for example in self.train_dataset.select(range(min(100, len(self.train_dataset)))): 330 | # Create a "good" and "bad" version 331 | good_response = example["text"] 332 | 333 | # Create a bad version by removing tool formatting 334 | bad_response = good_response.replace("[TOOL_CALL]", "").replace("[/TOOL_CALL]", "") 335 | 336 | preference_data.append({ 337 | "prompt": 
example["text"].split("Assistant:")[0] if "Assistant:" in example["text"] else "", 338 | "chosen": good_response, 339 | "rejected": bad_response 340 | }) 341 | 342 | return Dataset.from_list(preference_data) 343 | 344 | def cleanup(self): 345 | """Clean up resources.""" 346 | if self.writer is not None: 347 | self.writer.close() 348 | logger.info("TensorBoard writer closed") 349 | -------------------------------------------------------------------------------- /Tiny Tool Use/src/data/data_generator.py: -------------------------------------------------------------------------------- 1 | """Data generation for tool use training.""" 2 | 3 | import json 4 | import logging 5 | import random 6 | from typing import Dict, Any, List, Tuple, Optional 7 | from datasets import Dataset, load_dataset 8 | import pandas as pd 9 | from ..tools.executor import ToolExecutor 10 | import os 11 | from transformers import AutoTokenizer 12 | 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class DataGenerator: 19 | """Generates training data for tool use from various sources.""" 20 | 21 | def __init__(self, data_config: Dict[str, Any], tools_config: List[Dict[str, Any]], tokenizer_config: List[Dict[str, Any]]): 22 | self.data_config = data_config 23 | self.tools_config = tools_config 24 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["name"], trust_remote_code=tokenizer_config['trust_remote_code']) 25 | 26 | 27 | self.tool_executor = ToolExecutor(tools_config) 28 | self.strategy = data_config["strategy"] 29 | self.generation_type = data_config["generation_type"] 30 | 31 | 32 | def prepare_datasets(self) -> Tuple[Dataset, Dataset]: 33 | """Prepare training and evaluation datasets.""" 34 | if self.strategy == "toolbench" and self.generation_type.lower()=='real': 35 | return self._prepare_real_toolbench_data() 36 | elif self.strategy == "toolbench" and self.generation_type.lower()=='synthetic': 37 | return self._prepare_synthetic_toolbench_data() 38 | elif self.strategy == "teacher_mode" and self.generation_type.lower()=='synthetic': 39 | return self._prepare_teacher_mode_data() 40 | elif self.strategy == "manual_templates" and self.generation_type.lower()=='synthetic': 41 | return self._prepare_manual_template_data() 42 | else: 43 | raise ValueError(f"Unknown data strategy: {self.strategy}. 
Data generation strategy {self.generation_type} is not implemented for {self.strategy} ") 44 | 45 | 46 | def _download_from_google_drive(self, folder_url, destination_dir): 47 | 48 | import gdown, pathlib, zipfile 49 | 50 | destination_dir = pathlib.Path(destination_dir) 51 | 52 | files = gdown.download_folder( 53 | url = folder_url, 54 | quiet = False, 55 | use_cookies = False, 56 | output = destination_dir.as_posix() 57 | 58 | ) 59 | 60 | zip_path = next(p for p in files if p.endswith('data.zip')) 61 | 62 | print("✔ downloaded", zip_path) 63 | 64 | with zipfile.ZipFile(zip_path) as zf: 65 | zf.extractall(destination_dir.as_posix()+"/data") 66 | return 67 | 68 | 69 | 70 | 71 | 72 | def _prepare_synthetic_toolbench_data(self)->Tuple[Dataset, Dataset]: 73 | 74 | "Get the synthetic tool bench data" 75 | 76 | logger.info("Generating synthetic tool bench data...") 77 | 78 | synthetic_data = self._generate_synthetic_toolbench_data() 79 | 80 | return self._split_dataset(synthetic_data) 81 | 82 | 83 | 84 | 85 | 86 | 87 | def _prepare_real_toolbench_data(self) -> Tuple[Dataset, Dataset]: 88 | """Get toolbench data.""" 89 | logger.info("Obtaining toolbench data...") 90 | 91 | assistant_id = self.tokenizer.convert_tokens_to_ids("<|assistant|>") 92 | # synthetic_data = self._generate_synthetic_toolbench_data() 93 | 94 | #download the data from google drive link 95 | destination_dir = './data/toolbench/' 96 | if not os.path.exists(destination_dir): 97 | folder_url = 'https://drive.google.com/drive/folders/1TysbSWYpP8EioFu9xPJtpbJZMLLmwAmL' 98 | destination_dir = './data/toolbench/' 99 | self._download_from_google_drive(folder_url, destination_dir) 100 | 101 | #loading the toolbench data 102 | data = load_dataset("json", data_files="./data/toolbench/data/data/toolllama_G123_dfs_train.json")["train"] 103 | 104 | data = data.shuffle(seed=42).select(range(self.data_config["max_samples"])) 105 | 106 | def to_messages(conv): 107 | # Map any role names that can appear in ToolBench/Qwen 108 | role_map = { 109 | "system": "system", 110 | "user": "user", 111 | "assistant": "assistant", 112 | "tool": "tool", # tool_response in some repos 113 | "function": "tool", # treat function output the same as tool 114 | "tool_response": "tool", # safety net for other dumps 115 | "tool_call": "assistant", # if your dump keeps the call separate 116 | } 117 | 118 | unknown = {m["from"] for m in conv} - role_map.keys() 119 | if unknown: # fail fast if you meet something new 120 | raise ValueError(f"Unknown role(s): {unknown}") 121 | 122 | return [ 123 | {"role": role_map[m["from"]], "content": m["value"]} 124 | for m in conv 125 | ] 126 | 127 | def tokenize(sample): 128 | msgs = to_messages(sample["conversations"]) 129 | chat_text = self.tokenizer.apply_chat_template(msgs, tokenize=False, 130 | add_generation_prompt=False) # Qwen-3 Jinja template:contentReference[oaicite:1]{index=1} 131 | 132 | 133 | ids = self.tokenizer(chat_text, return_tensors="pt").input_ids[0] 134 | labels = ids.clone() 135 | 136 | # *** non-assistant masking *** 137 | ptr = 0 138 | for msg in msgs: 139 | n = len(self.tokenizer(msg["content"]).input_ids) + 1 # +EOS 140 | if msg["role"] != "assistant": 141 | labels[ptr:ptr+n] = -100 # ignore in loss 142 | ptr += n 143 | sample["input_ids"], sample["labels"] = ids, labels 144 | return sample 145 | 146 | tokenised = data.map(tokenize, remove_columns=data.column_names) 147 | tokenised = tokenised.shuffle(seed=42).train_test_split(test_size=1-self.data_config["train_split"]) 148 | 149 | dataset_train = 
tokenised["train"] 150 | dataset_eval = tokenised['test'] 151 | 152 | return dataset_train, dataset_eval 153 | 154 | 155 | 156 | 157 | def _prepare_teacher_mode_data(self) -> Tuple[Dataset, Dataset]: 158 | """Generate data using teacher mode (Toolformer-style).""" 159 | logger.info("Generating teacher mode data...") 160 | 161 | data = [] 162 | for _ in range(self.data_config.get("max_samples", 100)): 163 | conversation = self._generate_teacher_mode_example() 164 | data.append(conversation) 165 | 166 | logger.info(f"Generated {len(data)} teacher mode examples") 167 | return self._split_dataset(data) 168 | 169 | def _prepare_manual_template_data(self) -> Tuple[Dataset, Dataset]: 170 | """Generate data from manual templates with paraphrasing.""" 171 | logger.info("Generating data from manual templates...") 172 | 173 | canonical_examples = self._create_canonical_examples() 174 | 175 | bootstrapped_data = [] 176 | for example in canonical_examples: 177 | bootstrapped_data.append(example) 178 | paraphrases = self._simple_paraphrase(example) 179 | bootstrapped_data.extend(paraphrases) 180 | 181 | logger.info(f"Generated {len(bootstrapped_data)} template-based examples") 182 | return self._split_dataset(bootstrapped_data) 183 | 184 | def _generate_synthetic_toolbench_data(self) -> List[Dict[str, Any]]: 185 | """Generate synthetic ToolBench-style data.""" 186 | data = [] 187 | 188 | for tool in self.tools_config: 189 | for i in range(20): 190 | conversation = self._create_tool_conversation(tool) 191 | data.append({"text": conversation, "tool_name": tool["name"]}) 192 | 193 | return data 194 | 195 | def _create_tool_conversation(self, tool: Dict[str, Any]) -> str: 196 | """Create a conversation that uses a specific tool.""" 197 | tool_name = tool["name"] 198 | 199 | user_queries = { 200 | "calculator": ["What's 15 * 24?", "Can you calculate 45 + 67 - 12?"], 201 | "weather": ["What's the weather like in New York?", "Check London weather"], 202 | "search": ["Search for Python tutorials", "Find ML information"] 203 | } 204 | 205 | queries = user_queries.get(tool_name, [f"Use {tool_name}"]) 206 | user_query = random.choice(queries) 207 | 208 | if tool_name == "calculator": 209 | expression = random.choice(["15 * 24", "45 + 67 - 12", "(100 / 5) * 3"]) 210 | params = {"expression": expression} 211 | elif tool_name == "weather": 212 | location = random.choice(["New York", "London", "Tokyo"]) 213 | params = {"location": location} 214 | elif tool_name == "search": 215 | query = random.choice(["Python tutorials", "machine learning"]) 216 | params = {"query": query} 217 | else: 218 | params = {} 219 | 220 | result = self.tool_executor.execute_tool(tool_name, params) 221 | tool_call = json.dumps({"name": tool_name, "parameters": params}) 222 | result_str = json.dumps(result) 223 | 224 | conversation = f"""Human: {user_query} 225 | Assistant: {tool_call} 226 | """ 227 | 228 | return conversation 229 | 230 | def _format_toolbench_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 231 | """Format ToolBench data to our internal format.""" 232 | formatted_data = [] 233 | 234 | for example in data: 235 | conversation = example["conversation"] 236 | tool_name = example["tool_name"] 237 | 238 | # Simple format: just pass through 239 | formatted_data.append({ 240 | "input": conversation, 241 | "output": tool_name 242 | }) 243 | 244 | return formatted_data 245 | 246 | def _split_dataset(self, data: List[Dict[str, Any]]) -> Tuple[Dataset, Dataset]: 247 | """Split data into training and evaluation sets.""" 248 
| df = pd.DataFrame(data) 249 | 250 | # 80-20 train-test split 251 | train_df = df.sample(frac=0.8, random_state=42) 252 | test_df = df.drop(train_df.index) 253 | 254 | train_dataset = Dataset.from_pandas(train_df) 255 | test_dataset = Dataset.from_pandas(test_df) 256 | 257 | return train_dataset, test_dataset 258 | 259 | def _generate_base_conversations(self) -> List[Dict[str, Any]]: 260 | """Generate base conversations for teacher mode.""" 261 | base_conversations = [] 262 | 263 | for tool in self.tools_config: 264 | tool_name = tool["name"] 265 | description = tool["description"] 266 | 267 | # Simple static prompts for now 268 | conversation = f"Use the {tool_name} to {description.lower()}" 269 | base_conversations.append({ 270 | "conversation": conversation, 271 | "tool_name": tool_name 272 | }) 273 | 274 | return base_conversations 275 | 276 | def _insert_tool_calls(self, conversation: Dict[str, Any]) -> Optional[Dict[str, Any]]: 277 | """Insert tool calls into a base conversation (teacher model style).""" 278 | tool_name = conversation["tool_name"] 279 | base_convo = conversation["conversation"] 280 | 281 | # Naive insertion of tool call - just for demonstration 282 | if tool_name in base_convo: 283 | return conversation # Nothing to change 284 | 285 | # Insert the tool call before the user query 286 | user_query = f"Please {base_convo}" 287 | tool_call = json.dumps({"name": tool_name, "parameters": {}}) 288 | full_conversation = f"""Human: {user_query} 289 | Assistant: {tool_call} 290 | """ 291 | 292 | return { 293 | "conversation": full_conversation, 294 | "tool_name": tool_name 295 | } 296 | 297 | def _generate_teacher_mode_example(self) -> Dict[str, Any]: 298 | """Generate a teacher mode example with tool insertion.""" 299 | # Start with a base conversation 300 | topics = [ 301 | "I need help with calculations", 302 | "Can you tell me about the weather?", 303 | "I want to search for information" 304 | ] 305 | 306 | topic = random.choice(topics) 307 | tool = random.choice(self.tools_config) 308 | 309 | # Generate conversation with tool insertion 310 | conversation = self._create_tool_conversation(tool) 311 | 312 | return {"text": conversation, "tool_name": tool["name"]} 313 | 314 | def _create_canonical_examples(self) -> List[Dict[str, Any]]: 315 | """Create canonical examples for each tool.""" 316 | examples = [] 317 | 318 | templates = { 319 | "calculator": [ 320 | "Calculate {expression}", 321 | "What is {expression}?", 322 | "Compute {expression}", 323 | "Solve {expression}", 324 | "Find the result of {expression}" 325 | ], 326 | "weather": [ 327 | "What's the weather in {location}?", 328 | "Check weather for {location}", 329 | "Weather forecast for {location}", 330 | "How's the weather in {location}?", 331 | "Tell me about {location} weather" 332 | ], 333 | "search": [ 334 | "Search for {query}", 335 | "Find information about {query}", 336 | "Look up {query}", 337 | "Research {query}", 338 | "Get results for {query}" 339 | ] 340 | } 341 | 342 | for tool in self.tools_config: 343 | tool_name = tool["name"] 344 | tool_templates = templates.get(tool_name, [f"Use {tool_name}"]) 345 | 346 | # Create 10 canonical examples per tool 347 | for i in range(10): 348 | template = random.choice(tool_templates) 349 | 350 | if tool_name == "calculator": 351 | expressions = ["2 + 3", "10 * 5", "100 / 4", "15 - 7", "2 ** 3"] 352 | expression = random.choice(expressions) 353 | user_query = template.format(expression=expression) 354 | params = {"expression": expression} 355 | elif tool_name == 
"weather": 356 | locations = ["Paris", "Tokyo", "Sydney", "Berlin", "Cairo"] 357 | location = random.choice(locations) 358 | user_query = template.format(location=location) 359 | params = {"location": location} 360 | elif tool_name == "search": 361 | queries = ["Python", "AI", "cooking", "travel", "science"] 362 | query = random.choice(queries) 363 | user_query = template.format(query=query) 364 | params = {"query": query} 365 | else: 366 | user_query = template 367 | params = {} 368 | 369 | result = self.tool_executor.execute_tool(tool_name, params) 370 | tool_call = json.dumps({"name": tool_name, "parameters": params}) 371 | result_str = json.dumps(result) 372 | 373 | conversation = f"""Human: {user_query} 374 | Assistant: I'll help you with that. Let me use the {tool_name} function. 375 | 376 | [TOOL_CALL]{tool_call}[/TOOL_CALL] 377 | 378 | {result_str} 379 | 380 | Based on the result, the answer is {result.get('result', 'processed successfully')}.""" 381 | 382 | examples.append({"text": conversation, "tool_name": tool_name}) 383 | 384 | return examples 385 | 386 | def _simple_paraphrase(self, example: Dict[str, Any]) -> List[Dict[str, Any]]: 387 | """Generate simple paraphrases of an example.""" 388 | paraphrases = [] 389 | original_text = example["text"] 390 | 391 | # Simple paraphrasing by replacing words 392 | replacements = { 393 | "Calculate": "Compute", 394 | "What is": "What's", 395 | "Can you": "Could you", 396 | "Find": "Get", 397 | "Search for": "Look up", 398 | "weather": "forecast", 399 | "information": "info" 400 | } 401 | 402 | # Generate 3 paraphrases 403 | for i in range(3): 404 | paraphrased = original_text 405 | 406 | # Apply random replacements 407 | for original, replacement in replacements.items(): 408 | if original in paraphrased and random.random() < 0.5: 409 | paraphrased = paraphrased.replace(original, replacement) 410 | 411 | if paraphrased != original_text: 412 | paraphrases.append({ 413 | "text": paraphrased, 414 | "tool_name": example["tool_name"] 415 | }) 416 | 417 | return paraphrases 418 | 419 | def _split_dataset(self, data: List[Dict[str, Any]]) -> Tuple[Dataset, Dataset]: 420 | """Split data into train and eval datasets.""" 421 | random.shuffle(data) 422 | 423 | train_split = self.data_config.get("train_split", 0.8) 424 | split_idx = int(len(data) * train_split) 425 | 426 | train_data = data[:split_idx] 427 | eval_data = data[split_idx:] 428 | 429 | # Ensure we have at least some eval data 430 | if len(eval_data) == 0 and len(train_data) > 1: 431 | eval_data = [train_data.pop()] 432 | 433 | train_dataset = Dataset.from_list(train_data) 434 | eval_dataset = Dataset.from_list(eval_data) 435 | 436 | return train_dataset, eval_dataset 437 | 438 | def _log_dataset_sample(self, dataset: Dataset, num_samples: int = 3): 439 | """Log a few samples from the dataset.""" 440 | for i, example in enumerate(dataset): 441 | if i >= num_samples: 442 | break 443 | logger.info(f"Sample {i+1}: {example}") 444 | --------------------------------------------------------------------------------