├── NOTICE ├── Reasoning-through-incentives ├── requirements.txt ├── gvpo_vs_grpo.pdf ├── README_getting_started.md ├── src │ ├── train_gvpo.py │ └── gvpo_trainer.py ├── README.md └── gvpo.md ├── data ├── Restaurant_Dinner_Menu.pdf ├── Restaurant_Childrens_Menu.pdf └── Restaurant_week_specials.pdf ├── notebooks ├── images │ ├── multiagent.png │ ├── agentic_integration.png │ ├── agentic_orchestration.png │ └── bedrock-agent-kb-dynamodb.png ├── multi_agent_collaboration.ipynb └── reasoning_with_langgraph_bedrock_workshop.ipynb ├── scripts ├── __pycache__ │ ├── agent.cpython-311.pyc │ ├── blog_writer.cpython-311.pyc │ └── knowledge_base.cpython-311.pyc ├── agenteval.yml ├── lambda_function.py ├── bedrock.py ├── blog_writer.py ├── agent.py └── knowledge_base.py ├── requirements2.txt ├── requirements.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── README.md └── LICENSE /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /Reasoning-through-incentives/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | transformers>=4.30.0 3 | accelerate>=0.20.0 4 | tqdm>=4.65.0 5 | -------------------------------------------------------------------------------- /data/Restaurant_Dinner_Menu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/data/Restaurant_Dinner_Menu.pdf -------------------------------------------------------------------------------- /notebooks/images/multiagent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/notebooks/images/multiagent.png -------------------------------------------------------------------------------- /data/Restaurant_Childrens_Menu.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/data/Restaurant_Childrens_Menu.pdf -------------------------------------------------------------------------------- /data/Restaurant_week_specials.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/data/Restaurant_week_specials.pdf -------------------------------------------------------------------------------- /notebooks/images/agentic_integration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/notebooks/images/agentic_integration.png -------------------------------------------------------------------------------- /notebooks/images/agentic_orchestration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/notebooks/images/agentic_orchestration.png -------------------------------------------------------------------------------- /scripts/__pycache__/agent.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/scripts/__pycache__/agent.cpython-311.pyc 
-------------------------------------------------------------------------------- /Reasoning-through-incentives/gvpo_vs_grpo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/Reasoning-through-incentives/gvpo_vs_grpo.pdf -------------------------------------------------------------------------------- /notebooks/images/bedrock-agent-kb-dynamodb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/notebooks/images/bedrock-agent-kb-dynamodb.png -------------------------------------------------------------------------------- /scripts/__pycache__/blog_writer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/scripts/__pycache__/blog_writer.cpython-311.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/knowledge_base.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/agentic-orchestration/main/scripts/__pycache__/knowledge_base.cpython-311.pyc -------------------------------------------------------------------------------- /requirements2.txt: -------------------------------------------------------------------------------- 1 | duckduckgo-search 2 | grpcio>=1.60.0 3 | grpcio-tools>=1.60.0 4 | python-dotenv 5 | crewai==0.51.1 6 | crewai[tools]>=0.8.3 7 | litellm==1.46.1 8 | unstructured -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | boto3 2 | botocore 3 | setuptools==68.2.2 4 | langchain-aws==0.1.18 5 | ipywidgets 6 | chromadb 7 | typing-extensions 8 | langchain-community==0.2.17 9 | langchain-core==0.2.40 10 | langchain_huggingface==0.0.3 11 | langchain==0.2.16 12 | langgraph==0.2.22 13 | sec_api 14 | opensearch-py 15 | retrying 16 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /scripts/agenteval.yml: -------------------------------------------------------------------------------- 1 | evaluator: 2 | model: claude-3 3 | target: 4 | bedrock_agent_alias_id: {agent_id} 5 | bedrock_agent_id: none 6 | type: bedrock-agent 7 | tests: 8 | check_for_chicken_dinner: 9 | expected_results: 10 | - The agent returns a list of dishes from the dinner menu that contain chicken. 11 | steps: 12 | - Ask the agent for the dishes on the dinner menu that contain chicken. 13 | make_and_check_booking: 14 | steps: 15 | - Ask the agent to make a booking for Anna, 2 people, 16 July at 7pm. 16 | - Using the booking ID, check for the booking details 17 | expected_results: 18 | - The agent returns with the booking ID 19 | - The booking details are:
Name Anna, Number of guests 2, Date 16 July, Time 7pm 20 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. Our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), so looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /Reasoning-through-incentives/README_getting_started.md: -------------------------------------------------------------------------------- 1 | # GVPO Implementation 2 | 3 | Minimal implementation of **Group Variance Policy Optimization (GVPO)** for LLM post-training. 4 | 5 | ## Key Algorithm 6 | 7 | GVPO improves upon GRPO by incorporating KL constraints directly into gradient weights: 8 | 9 | ```python 10 | # GRPO (baseline) 11 | advantages = (rewards - rewards.mean()) / rewards.std() 12 | 13 | # GVPO (this implementation) 14 | advantages = (rewards - rewards.mean()) - beta * (log_ratio - log_ratio.mean()) 15 | loss = -beta * (advantages * log_probs).sum() / (k - 1) 16 | ``` 17 | 18 | ## Mathematical Foundation 19 | 20 | **Weight Formula:** 21 | ``` 22 | w_i = (R(x,y_i) - R̄) - β(log(π_θ/π_θ') - log(π_θ/π_θ')̄) 23 | ``` 24 | 25 | **Loss Function:** 26 | ``` 27 | L = -β * Σ w_i * log π_θ(y_i|x) / (k-1) 28 | ``` 29 | 30 | The `(k-1)` factor is the Bessel correction for unbiased variance estimation. 31 | 32 | ## Installation 33 | 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | ## Usage 39 | 40 | ### Basic Training 41 | 42 | ```python 43 | from transformers import AutoModelForCausalLM, AutoTokenizer 44 | from gvpo_trainer import GVPOTrainer 45 | 46 | # Load models 47 | model = AutoModelForCausalLM.from_pretrained("your-model") 48 | ref_model = AutoModelForCausalLM.from_pretrained("your-model") 49 | tokenizer = AutoTokenizer.from_pretrained("your-model") 50 | 51 | # Initialize trainer 52 | trainer = GVPOTrainer( 53 | model=model, 54 | ref_model=ref_model, 55 | tokenizer=tokenizer, 56 | beta=0.1, # KL constraint coefficient 57 | num_samples_per_prompt=4 58 | ) 59 | 60 | # Training step 61 | prompts = ["Solve: 2 + 2 = ?"] 62 | completions = [["4", "Four", "2+2=4", "The answer is 4"]] 63 | rewards = [[1.0, 0.8, 0.9, 0.95]] 64 | 65 | metrics = trainer.train_step(prompts, completions, rewards) 66 | ``` 67 | 68 | ### Run Example 69 | 70 | ```bash 71 | python train_gvpo.py 72 | ``` 73 | 74 | ## Key Differences from GRPO 75 | 76 | | Feature | GRPO | GVPO | 77 | |---------|------|------| 78 | | **Advantage** | `(R - R̄) / σ_R` | `(R - R̄) - β(log π_θ/π_θ' - mean)` | 79 | | **KL Constraint** | External penalty | Built into weights | 80 | | **Normalization** | Std division | Only centering | 81 | | **Stability** | Requires clipping | Inherently stable | 82 | 83 | ## Hyperparameters 84 | 85 | - **beta (β)**: KL constraint strength (default: 0.1) 86 | - Paper shows robustness across [0.01, 0.5] 87 | - Lower β = more exploration 88 | - Higher β = stay closer to reference policy 89 | 90 | - **num_samples_per_prompt (k)**: Responses per prompt (default: 4) 91 | - Paper tests k ∈ [2, 32] 92 | - Higher k improves performance but increases compute 93 | 94 | ## Reward Function 95 | 96 | Replace `dummy_reward_function` in `train_gvpo.py` with your actual reward model: 97 | 98 | ```python 99 | def reward_function(prompt: str, completion: str) -> float: 100 | # For math: use verifier to check correctness 101 | # For general: use reward model (e.g., from RLHF) 102 | return score 103 | ``` 104 | 105 | ## Architecture 106 | 107 | ``` 108 | gvpo_trainer.py # Core GVPO algorithm 109 | train_gvpo.py # Training script 110 | requirements.txt # Dependencies 111 
| README.md # Documentation 112 | ``` 113 | 114 | ## Citation 115 | 116 | Based on the GVPO paper's mathematical formulation: 117 | - Zero-sum weight constraint eliminates partition function 118 | - No importance sampling needed (unlike PPO/GRPO) 119 | - Theoretical guarantee of convergence to optimal policy 120 | 121 | ## Performance 122 | 123 | Paper results on Qwen2.5-Math-7B: 124 | 125 | | Benchmark | GRPO | GVPO | Improvement | 126 | |-----------|------|------|-------------| 127 | | AIME2024 | 14.79 | 20.72 | +40% | 128 | | AMC | 55.42 | 62.65 | +13% | 129 | | MATH500 | 80.00 | 83.80 | +5% | 130 | 131 | ## License 132 | 133 | MIT 134 | -------------------------------------------------------------------------------- /Reasoning-through-incentives/src/train_gvpo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example training script for GVPO 3 | """ 4 | 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from gvpo_trainer import GVPOTrainer 8 | from typing import List, Tuple 9 | 10 | 11 | def dummy_reward_function(prompt: str, completion: str) -> float: 12 | """ 13 | Placeholder reward function - replace with actual reward model 14 | For math problems, this would be a verifier checking correctness 15 | """ 16 | # Simple heuristic: reward longer, more detailed responses 17 | return len(completion.split()) / 100.0 18 | 19 | 20 | def prepare_training_data() -> List[Tuple[str, List[str], List[float]]]: 21 | """ 22 | Prepare training data in format: (prompt, [completions], [rewards]) 23 | In practice, load from dataset 24 | """ 25 | # Example math prompts 26 | prompts = [ 27 | "Solve: What is 15 * 23?", 28 | "Find the derivative of f(x) = x^2 + 3x + 2", 29 | "Calculate the area of a circle with radius 5" 30 | ] 31 | 32 | return [(p, [], []) for p in prompts] # Completions generated during training 33 | 34 | 35 | def main(): 36 | # Configuration 37 | model_name = "gpt2" # Replace with your model 38 | beta = 0.1 39 | num_samples = 4 40 | num_epochs = 3 41 | batch_size = 2 42 | 43 | print(f"Loading model: {model_name}") 44 | tokenizer = AutoTokenizer.from_pretrained(model_name) 45 | if tokenizer.pad_token is None: 46 | tokenizer.pad_token = tokenizer.eos_token 47 | 48 | # Load model and reference model 49 | model = AutoModelForCausalLM.from_pretrained(model_name) 50 | ref_model = AutoModelForCausalLM.from_pretrained(model_name) 51 | 52 | # Initialize GVPO trainer 53 | trainer = GVPOTrainer( 54 | model=model, 55 | ref_model=ref_model, 56 | tokenizer=tokenizer, 57 | beta=beta, 58 | num_samples_per_prompt=num_samples, 59 | learning_rate=1e-6 60 | ) 61 | 62 | print(f"\nGVPO Training Configuration:") 63 | print(f" Beta (β): {beta}") 64 | print(f" Samples per prompt (k): {num_samples}") 65 | print(f" Epochs: {num_epochs}") 66 | print(f" Batch size: {batch_size}\n") 67 | 68 | # Training loop 69 | training_data = prepare_training_data() 70 | 71 | for epoch in range(num_epochs): 72 | print(f"Epoch {epoch + 1}/{num_epochs}") 73 | epoch_metrics = {"loss": [], "kl": [], "advantage": []} 74 | 75 | for i in range(0, len(training_data), batch_size): 76 | batch = training_data[i:i + batch_size] 77 | prompts = [item[0] for item in batch] 78 | 79 | # Generate completions 80 | print(f" Generating {num_samples} completions per prompt...") 81 | completions = trainer.generate_completions(prompts) 82 | 83 | # Compute rewards 84 | print(f" Computing rewards...") 85 | rewards = [] 86 | for prompt, comps in zip(prompts, completions): 87 | 
prompt_rewards = [dummy_reward_function(prompt, c) for c in comps] 88 | rewards.append(prompt_rewards) 89 | 90 | # Training step 91 | print(f" Training step...") 92 | metrics = trainer.train_step(prompts, completions, rewards) 93 | 94 | epoch_metrics["loss"].append(metrics["loss"]) 95 | epoch_metrics["kl"].append(metrics["mean_kl"]) 96 | epoch_metrics["advantage"].append(metrics["mean_advantage"]) 97 | 98 | print(f" Loss: {metrics['loss']:.4f} | " 99 | f"KL: {metrics['mean_kl']:.4f} | " 100 | f"Advantage: {metrics['mean_advantage']:.4f}") 101 | 102 | # Epoch summary 103 | avg_loss = sum(epoch_metrics["loss"]) / len(epoch_metrics["loss"]) 104 | avg_kl = sum(epoch_metrics["kl"]) / len(epoch_metrics["kl"]) 105 | print(f" Epoch {epoch + 1} Summary: Loss={avg_loss:.4f}, KL={avg_kl:.4f}\n") 106 | 107 | # Save model 108 | output_dir = "./gvpo_model" 109 | print(f"Saving model to {output_dir}") 110 | model.save_pretrained(output_dir) 111 | tokenizer.save_pretrained(output_dir) 112 | print("Training complete!") 113 | 114 | 115 | if __name__ == "__main__": 116 | main() 117 |
-------------------------------------------------------------------------------- /scripts/lambda_function.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | import boto3 4 | 5 | dynamodb = boto3.resource('dynamodb') 6 | table = dynamodb.Table('restaurant_bookings') 7 | 8 | def get_named_parameter(event, name): 9 | """ 10 | Get a parameter from the lambda event 11 | """ 12 | return next((item['value'] for item in event['parameters'] if item['name'] == name), None)  # None when the parameter is absent, instead of raising StopIteration 13 | 14 | 15 | def get_booking_details(booking_id): 16 | """ 17 | Retrieve details of a restaurant booking 18 | 19 | Args: 20 | booking_id (string): The ID of the booking to retrieve 21 | """ 22 | try: 23 | response = table.get_item(Key={'booking_id': booking_id}) 24 | if 'Item' in response: 25 | return response['Item'] 26 | else: 27 | return {'message': f'No booking found with ID {booking_id}'} 28 | except Exception as e: 29 | return {'error': str(e)} 30 | 31 | 32 | def create_booking(date, name, hour, num_guests): 33 | """ 34 | Create a new restaurant booking 35 | 36 | Args: 37 | date (string): The date of the booking 38 | name (string): Name to identify your reservation 39 | hour (string): The hour of the booking 40 | num_guests (integer): The number of guests for the booking 41 | """ 42 | try: 43 | booking_id = str(uuid.uuid4())[:8] 44 | table.put_item( 45 | Item={ 46 | 'booking_id': booking_id, 47 | 'date': date, 48 | 'name': name, 49 | 'hour': hour, 50 | 'num_guests': num_guests 51 | } 52 | ) 53 | return {'booking_id': booking_id} 54 | except Exception as e: 55 | return {'error': str(e)} 56 | 57 | 58 | def delete_booking(booking_id): 59 | """ 60 | Delete an existing restaurant booking 61 | 62 | Args: 63 | booking_id (str): The ID of the booking to delete 64 | """ 65 | try: 66 | response = table.delete_item(Key={'booking_id': booking_id}) 67 | if response['ResponseMetadata']['HTTPStatusCode'] == 200: 68 | return {'message': f'Booking with ID {booking_id} deleted successfully'} 69 | else: 70 | return {'message': f'Failed to delete booking with ID {booking_id}'} 71 | except Exception as e: 72 | return {'error': str(e)} 73 | 74 | 75 | def lambda_handler(event, context): 76 | # get the action group used during the invocation of the lambda function 77 | actionGroup = event.get('actionGroup', '') 78 | 79 | # name of the function that should be invoked 80 | function = event.get('function', '') 81 | 82 | #
parameters to invoke function with 83 | parameters = event.get('parameters', []) 84 | 85 | if function == 'get_booking_details': 86 | booking_id = get_named_parameter(event, "booking_id") 87 | if booking_id: 88 | response = get_booking_details(booking_id)  # keep as a dict; serialize once below 89 | responseBody = {'TEXT': {'body': json.dumps(response, default=str)}} 90 | else: 91 | responseBody = {'TEXT': {'body': 'Missing booking_id parameter'}} 92 | 93 | elif function == 'create_booking': 94 | date = get_named_parameter(event, "date") 95 | name = get_named_parameter(event, "name") 96 | hour = get_named_parameter(event, "hour") 97 | num_guests = get_named_parameter(event, "num_guests") 98 | 99 | if date and hour and num_guests: 100 | response = create_booking(date, name, hour, num_guests) 101 | responseBody = {'TEXT': {'body': json.dumps(response, default=str)}} 102 | else: 103 | responseBody = {'TEXT': {'body': 'Missing required parameters'}} 104 | 105 | elif function == 'delete_booking': 106 | booking_id = get_named_parameter(event, "booking_id") 107 | if booking_id: 108 | response = delete_booking(booking_id) 109 | responseBody = {'TEXT': {'body': json.dumps(response, default=str)}} 110 | else: 111 | responseBody = {'TEXT': {'body': 'Missing booking_id parameter'}} 112 | 113 | else: 114 | responseBody = {'TEXT': {'body': 'Invalid function'}} 115 | 116 | action_response = { 117 | 'actionGroup': actionGroup, 118 | 'function': function, 119 | 'functionResponse': { 120 | 'responseBody': responseBody 121 | } 122 | } 123 | 124 | function_response = {'response': action_response, 'messageVersion': event['messageVersion']} 125 | print("Response: {}".format(function_response)) 126 | 127 | return function_response 128 |
-------------------------------------------------------------------------------- /scripts/bedrock.py: -------------------------------------------------------------------------------- 1 | import os 2 | import boto3 3 | from langchain_aws import BedrockLLM, ChatBedrock, ChatBedrockConverse, BedrockEmbeddings 4 | #from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory 5 | from botocore.config import Config 6 | from botocore.exceptions import ClientError 7 | 8 | bedrock_agent_runtime_client = boto3.client('bedrock-agent-runtime') 9 | 10 | ## Setup LLMs 11 | def get_llm(model_id: str, aws_region: str='us-west-2'): 12 | config = Config( 13 | retries = dict( 14 | max_attempts = 10, 15 | total_max_attempts = 25, 16 | ) 17 | ) 18 | bedrock_client = boto3.client("bedrock-runtime", config=config, region_name=aws_region) 19 | 20 | inference_modifier = { 21 | "max_tokens": 4096, 22 | "temperature": 0.01, 23 | "top_k": 50, 24 | "top_p": 0.95, 25 | "stop_sequences": ["\n\n\nHuman"], 26 | } 27 | 28 | if 'claude-3-5' in model_id: 29 | inference_modifier = { 30 | "max_tokens": 4096, 31 | "temperature": 0.01, 32 | "top_k": 50, 33 | "top_p": 0.95, 34 | "stop_sequences": ["\n\n\nHuman"], 35 | } 36 | llm = ChatBedrock( 37 | model_id=model_id, 38 | client=bedrock_client, 39 | model_kwargs=inference_modifier, 40 | region_name=aws_region, 41 | ) 42 | elif 'claude-3' in model_id or 'mistral' in model_id or 'llama3-1' in model_id: 43 | llm = ChatBedrockConverse( 44 | model=model_id, 45 | client=bedrock_client, 46 | temperature=0.01, 47 | max_tokens=2048 if 'llama3-1' in model_id else 4096, 48 | region_name=aws_region, 49 | ) 50 | else: 51 | llm = BedrockLLM( 52 | model_id=model_id, 53 | client=bedrock_client, 54 | model_kwargs={"temperature": 0.1, "max_gen_len":4096}, 55 | ) 56 | 57 | return llm 58 | 59 | def get_embedding(model_id: 
str="amazon.titan-embed-text-v2:0", aws_region: str='us-west-2'): 60 | config = Config( 61 | retries = dict( 62 | max_attempts = 10, 63 | total_max_attempts = 25, 64 | ) 65 | ) 66 | bedrock_client = boto3.client("bedrock-runtime", config=config, region_name=aws_region) 67 | 68 | return BedrockEmbeddings(client=bedrock_client, region_name=aws_region, model_id=model_id) 69 | 70 | def check_and_delete_iam_policy(policy_name): 71 | # Create an IAM client 72 | iam = boto3.client('iam') 73 | 74 | try: 75 | # Try to get the policy (this ARN form assumes an AWS managed policy) 76 | response = iam.get_policy(PolicyArn=f'arn:aws:iam::aws:policy/{policy_name}') 77 | 78 | # If we reach here, the policy exists 79 | print(f"Policy '{policy_name}' exists. Attempting to delete...") 80 | 81 | # First, we need to detach the policy from all entities 82 | detach_policy(iam, response['Policy']['Arn']) 83 | 84 | # Now we can delete the policy 85 | iam.delete_policy(PolicyArn=response['Policy']['Arn']) 86 | print(f"Policy '{policy_name}' has been deleted successfully.") 87 | 88 | except ClientError as e: 89 | if e.response['Error']['Code'] == 'NoSuchEntity': 90 | print(f"Policy '{policy_name}' does not exist.") 91 | else: 92 | print(f"An error occurred: {e}") 93 | 94 | def detach_policy(iam, policy_arn): 95 | # Detach from users 96 | for user in iam.list_entities_for_policy(PolicyArn=policy_arn, EntityFilter='User')['PolicyUsers']: 97 | iam.detach_user_policy(UserName=user['UserName'], PolicyArn=policy_arn) 98 | 99 | # Detach from groups 100 | for group in iam.list_entities_for_policy(PolicyArn=policy_arn, EntityFilter='Group')['PolicyGroups']: 101 | iam.detach_group_policy(GroupName=group['GroupName'], PolicyArn=policy_arn) 102 | 103 | # Detach from roles 104 | for role in iam.list_entities_for_policy(PolicyArn=policy_arn, EntityFilter='Role')['PolicyRoles']: 105 | iam.detach_role_policy(RoleName=role['RoleName'], PolicyArn=policy_arn) 106 | 107 | def check_table_exists(table_name): 108 | # Create a DynamoDB client 109 | dynamodb = boto3.client('dynamodb') 110 | 111 | try: 112 | # Try to describe the table 113 | dynamodb.describe_table(TableName=table_name) 114 | return True 115 | except ClientError:  # table does not exist or is not accessible 116 | return False 117 | 118 | def check_lambda_function_exists(function_name): 119 | # Create a Lambda client 120 | lambda_client = boto3.client('lambda') 121 | try: 122 | # Try to get the function configuration 123 | lambda_client.get_function(FunctionName=function_name) 124 | return True 125 | except ClientError:  # function does not exist or is not accessible 126 | return False 127 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reasoning Orchestration Workshop using Amazon Bedrock, Knowledge Bases, and LangChain 2 | This repository provides step-by-step instructions for creating a collaborative multi-agent system that leverages graph-based orchestration. It demonstrates how to combine Amazon Bedrock Agents with agents developed using open-source frameworks, enabling enhanced reasoning capabilities and seamless integration. The exercise will guide you through the process of building a reasoning orchestration system using [Amazon Bedrock](https://aws.amazon.com/bedrock/), [Knowledge Bases for Amazon Bedrock](https://aws.amazon.com/bedrock/knowledge-bases/) and [Agents for Amazon Bedrock](https://aws.amazon.com/bedrock/agents/) as well as custom fine-tuned models.
We will also explore the integration of Bedrock agents with the open-source orchestration frameworks [LangGraph](https://langchain-ai.github.io/langgraph/) and [CrewAI](https://github.com/crewAIInc/crewAI) for dispatching and reasoning. 3 | 4 | 
6 | 7 | ## Overview 8 | In this workshop, you will learn how to: 9 | 10 | 1. Build a multimodal agentic orchestration framework using AWS and open-source tools 11 | 2. Set up and configure Amazon Bedrock, a managed service for building applications on large language models (LLMs), including its Agents and Knowledge Bases features. 12 | 3. Set up an open-source RAG solution using Chroma and an embedding engine of your choice. 13 | 4. Utilize LangChain, a framework for building applications with large language models, to orchestrate the reasoning process. 14 | 5. Integrate LangGraph, a tool for managing agentic services, to dispatch and reason about the various components of your system. 15 | 6. Integrate open-source LangGraph with an Amazon Bedrock Agent that is associated with AWS Lambda and Amazon Bedrock Knowledge Bases (a minimal example appears at the end of this README) 16 | 17 | ## Prerequisites 18 | Workshop practitioners are expected to have working experience with LLMs, Jupyter Notebook, and Python. 19 | 20 | * AWS account with appropriate permissions to access the services used in this workshop, including Bedrock model access and IAM permissions for S3, DynamoDB, Lambda, and others. 21 | * Basic understanding of large language models, knowledge management, and AI-powered applications. 22 | * Familiarity with the Python programming language. 23 | 24 | ## Getting Started 25 | 26 | Set up Amazon Bedrock: Follow the official Amazon Bedrock documentation to create your Bedrock environment and configure the necessary permissions and resources. 27 | 28 | Integrate Knowledge Bases: Explore the Knowledge Bases service and learn how to integrate it with your Bedrock-powered application. Ensure that your knowledge base is populated with relevant information to support your reasoning tasks. 29 | 30 | Utilize LangChain: Dive into the LangChain framework and understand how to use it to orchestrate the reasoning process. Explore the various LangChain components, such as agents, chains, and prompts, to build your reasoning system. 31 | 32 | Integrate LangGraph: Introduce LangGraph into your system to manage the agentic services involved in the reasoning process. Learn how to dispatch tasks and reason about the various components of your system using LangGraph. 33 | 34 | Develop your Reasoning Orchestration System: Combine the knowledge and tools you've acquired to build your reasoning orchestration system. Ensure that the different components (Bedrock, Knowledge Bases, LangChain, and LangGraph) work seamlessly together to provide the desired functionality. 35 | 36 | Test and Refine: Thoroughly test your reasoning orchestration system, and make any necessary adjustments to improve its performance and reliability. 37 | 38 | ## Model fine-tuning 39 | GVPO (Group Variance Policy Optimization) addresses the training instability of GRPO while providing stronger theoretical guarantees. The key innovation is incorporating the analytical solution to KL-constrained reward maximization directly into the gradient weights, through a zero-sum weight constraint that eliminates the intractable partition function. 40 | 41 | ## Definitions 42 | 43 | ### Agents or Agentic Services 44 | Agents are created to fulfill specific roles and responsibilities. Each agent has predefined capabilities and embedded intelligence that enable it to focus on executing particular components or aspects of a project.
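To make this definition concrete, here is a minimal sketch of a role-scoped agent declared with CrewAI, the framework used in `scripts/blog_writer.py`. This is illustrative only: the role, goal, backstory, and model ID are placeholder values, and `get_llm` is the Bedrock helper from `scripts/bedrock.py`.

```python
from crewai import Agent     # agent abstraction used in scripts/blog_writer.py
from bedrock import get_llm  # Bedrock LLM factory from scripts/bedrock.py

# A hypothetical research agent: role, goal, and backstory scope its behavior,
# while the Bedrock-hosted LLM supplies the embedded intelligence.
researcher = Agent(
    role="Researcher",
    goal="Collect recent, factual material on a given topic",
    backstory="You gather information and hand your findings to downstream agents.",
    allow_delegation=False,  # keep this agent focused on its own task
    llm=get_llm("anthropic.claude-3-sonnet-20240229-v1:0"),  # placeholder model ID
    verbose=True,
)
```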
45 | 46 | ### Key Elements of an Agent 47 | - Planning: Prompting that elicits reasoning in large language models enables AI agents to engage in complex problem-solving and planning processes. 48 | - Reflection: This involves iterative refinement with self-feedback, allowing AI agents to continuously improve their performance. 49 | - Role Playing: Each agent is assigned a specific role, such as a researcher or analyst, to focus their efforts. 50 | - Tasks: Agents maintain attention on their assigned tasks, ensuring efficient execution. 51 | - Tools: Equipped with various tools for data retrieval, processing, and interaction. 52 | - Collaboration: Agents collaborate with each other to complete tasks. 53 | - Guardrails: Safety measures and protocols to ensure reliable and ethical operations. 54 | - Memory: Ability to store and recall past interactions and data, enhancing decision-making. 55 | 56 | ### Multi-Agent Collaboration 57 | Multi-agent collaboration involves multiple agents working together to achieve complex goals. This collaboration can take various forms depending on the nature and requirements of the tasks. Effective collaboration ensures that tasks are completed efficiently and accurately. Here are a few typical types: 58 | 59 | - Sequential: Agents perform tasks in a predetermined order, where each step depends on the completion of the previous one. 60 | - Hierarchical: Agents follow a structured hierarchy, with higher-level agents overseeing and coordinating the activities of lower-level agents. 61 | - Asynchronous: Agents operate independently, handling tasks as they arise without adhering to a fixed sequence, allowing for flexibility and parallel processing. 62 | 63 | ## Resources 64 | * [Amazon Bedrock Documentation](https://docs.aws.amazon.com/bedrock/) 65 | * [Amazon OpenSearch Documentation](https://docs.aws.amazon.com/opensearch-service/) 66 | * [Amazon Bedrock Agents Documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/agents.html) 67 | * [LangChain Agent Documentation](https://python.langchain.com/v0.1/docs/modules/agents/) 68 | * [LangGraph Documentation](https://langchain-ai.github.io/langgraph/) 69 | * [CrewAI Framework](https://github.com/crewAIInc/crewAI) 70 | 71 | ## Conclusion 72 | By the end of this workshop, you will have a solid understanding of how to build an agentic orchestration system using Amazon Bedrock, Knowledge Bases, LangChain, and LangGraph. This knowledge will enable you to create powerful generative AI-powered applications that can effectively reason about complex problems and make informed decisions.
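## Example: Invoking a Bedrock Agent from LangGraph

The snippet below is a minimal, illustrative sketch (not the exact notebook code) of the integration described above: a single LangGraph node that delegates a question to an Amazon Bedrock Agent via `boto3`. The agent ID, alias ID, and region are placeholders you would replace with your own values.

```python
import uuid
import boto3
from typing import TypedDict
from langgraph.graph import StateGraph, END

runtime = boto3.client("bedrock-agent-runtime", region_name="us-west-2")  # placeholder region

class State(TypedDict):
    question: str
    answer: str

def call_bedrock_agent(state: State) -> dict:
    # Invoke the Bedrock agent and join the streamed completion chunks.
    response = runtime.invoke_agent(
        agentId="YOUR_AGENT_ID",          # placeholder
        agentAliasId="YOUR_AGENT_ALIAS",  # placeholder
        sessionId=str(uuid.uuid4()),
        inputText=state["question"],
    )
    answer = "".join(
        event["chunk"]["bytes"].decode("utf-8")
        for event in response["completion"]
        if "chunk" in event
    )
    return {"answer": answer}  # LangGraph merges partial state updates

graph = StateGraph(State)
graph.add_node("restaurant_agent", call_bedrock_agent)
graph.set_entry_point("restaurant_agent")
graph.add_edge("restaurant_agent", END)
app = graph.compile()

print(app.invoke({"question": "Which dinner dishes contain chicken?"})["answer"])
```

In the workshop notebooks this node would sit alongside other LangGraph nodes (for example, a RAG node backed by Chroma), with the graph's edges deciding which agent handles each request.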
73 | 74 | 75 | -------------------------------------------------------------------------------- /scripts/blog_writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from crewai import Agent, Task, Crew, Process 3 | from dotenv import load_dotenv 4 | from langchain.tools import DuckDuckGoSearchRun 5 | from langchain_community.tools.tavily_search import TavilySearchResults 6 | import boto3 7 | from langchain_aws import BedrockLLM, ChatBedrock, ChatBedrockConverse 8 | #from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory 9 | from botocore.config import Config 10 | from bedrock import * 11 | from crewai_tools import ( 12 | DirectoryReadTool, 13 | FileReadTool, 14 | SerperDevTool, 15 | WebsiteSearchTool 16 | ) 17 | 18 | load_dotenv() 19 | duck_search_tool = DuckDuckGoSearchRun() 20 | tavily_tool = TavilySearchResults(max_results=5) 21 | web_rag_tool = WebsiteSearchTool() 22 | 23 | 24 | class blogAgents(): 25 | def __init__(self, topic, model_id): 26 | self.topic = topic 27 | self.model_id = model_id 28 | 29 | def planner(self, topic, model_id): 30 | return Agent( 31 | role="Content Planner", 32 | goal=f"""Plan engaging and factually accurate content on {topic}""", 33 | backstory=f"""You're working on planning a blog article about the topic: {topic}. \n 34 | You collect information by searching the web for the latest developments that directly relate to the {topic}. \n 35 | You help the audience learn something and make informed decisions. Your work is the basis for the Content Writer to write an article on this {topic}.""", 36 | allow_delegation=False, 37 | tools=[duck_search_tool, tavily_tool, web_rag_tool], 38 | llm=get_llm(model_id), 39 | verbose=True 40 | ) 41 | 42 | def writer(self, topic, model_id): 43 | return Agent( 44 | role="Content Writer", 45 | goal=f"Write an insightful and factually accurate opinion piece about the topic: {topic}", 46 | backstory=f"""You're working on writing a new opinion piece about the topic: {topic}. You base your writing on the work of \n 47 | the Content Planner, who provides an outline \n 48 | and relevant context about the topic. \n 49 | You follow the main objectives and \n 50 | direction of the outline, \n 51 | as provided by the Content Planner. \n 52 | You also provide objective and impartial insights \n 53 | and back them up with information \n 54 | provided by the Content Planner. \n 55 | You acknowledge in your opinion piece \n 56 | when your statements are opinions \n 57 | as opposed to objective statements.""", 58 | allow_delegation=False, 59 | llm=get_llm(model_id), 60 | verbose=True 61 | ) 62 | 63 | def editor(self, model_id): 64 | return Agent( 65 | role="Editor", 66 | goal="Edit a given blog post to align with " 67 | "the writing style of the organization. ", 68 | backstory="You are an editor who receives a blog post from the Content Writer. " 69 | "Your goal is to review the blog post to ensure that it follows journalistic best practices, " 70 | "provides balanced viewpoints when providing opinions or assertions, " 71 | "and also avoids major controversial topics or opinions when possible.", 72 | allow_delegation=False, 73 | llm=get_llm(model_id), 74 | verbose=True 75 | ) 76 | 77 | 78 | class blogTasks(): 79 | def __init__(self, topic): 80 | self.topic = topic 81 | 82 | def plan(self, planner, topic): 83 | return Task( 84 | description=( 85 | f"""1. Prioritize the latest trends, key players, and noteworthy news on {topic}.\n 86 | 2.
Identify the target audience, considering their interests and pain points.\n 87 | 3. Develop a detailed content outline including an introduction, key points, and a call to action.\n 88 | 4. Include SEO keywords and relevant data or sources.""" 89 | ), 90 | expected_output=f"""Cover the latest developments of the {topic} with sufficient depth, as a domain expert. 91 | A comprehensive content plan document with an outline, audience analysis, 92 | SEO keywords, and resources.""", 93 | agent=planner, 94 | ) 95 | def write(self, writer, topic): 96 | return Task( 97 | description=( 98 | f"""1. Use the content plan to craft a compelling blog post on {topic}.\n 99 | 2. Incorporate SEO keywords naturally.\n 100 | 3. Sections/Subtitles are properly named in an engaging manner.\n 101 | 4. Ensure the post is structured with an engaging introduction, insightful body, and a summarizing conclusion.\n 102 | 5. Proofread for grammatical errors and alignment with the brand's voice""" 103 | ), 104 | expected_output="A well-written blog post, written like a professional writer. " 105 | "You are a domain expert and your blog is for other subject experts, " 106 | "in markdown format, ready for publication; " 107 | "each section should have 2 or 3 paragraphs.", 108 | agent=writer, 109 | ) 110 | 111 | def edit(self, editor): 112 | return Task( 113 | description=("Proofread the given blog post for " 114 | "grammatical errors and " 115 | "alignment with the brand's voice."), 116 | expected_output="A well-written blog post in markdown format, " 117 | "ready for publication, " 118 | "each section should have 2 or 3 paragraphs.", 119 | agent=editor, 120 | output_file='./blogPost.txt',  # CrewAI writes the final edited post here 121 | ) 122 | 123 | class blogCrew(): 124 | def __init__(self, topic, model_id): 125 | self.topic = topic 126 | self.model_id = model_id 127 | 128 | def run(self): 129 | agents = blogAgents(self.topic, self.model_id) 130 | tasks = blogTasks(self.topic) 131 | 132 | planner_agent = agents.planner(self.topic, self.model_id) 133 | writer_agent = agents.writer(self.topic, self.model_id) 134 | editor_agent = agents.editor(self.model_id) 135 | 136 | plan_task = tasks.plan(planner_agent, self.topic) 137 | write_task = tasks.write(writer_agent, self.topic) 138 | edit_task = tasks.edit(editor_agent) 139 | 140 | 141 | crew = Crew( 142 | agents=[planner_agent, writer_agent, editor_agent], 143 | tasks=[plan_task, write_task, edit_task], 144 | verbose=True, 145 | memory=True, 146 | embedder={ 147 | "provider": "huggingface", 148 | "config": {"model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"}, 149 | }, 150 | cache=True, 151 | process=Process.sequential # Sequential process: tasks execute one after the other, and the outcome of the previous task is passed as context to the next 152 | ) 153 | 154 | result = crew.kickoff() 155 | return result
-------------------------------------------------------------------------------- /Reasoning-through-incentives/src/gvpo_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | GVPO (Group Variance Policy Optimization) Trainer 3 | Based on the paper's mathematical formulation and TRL's architecture 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from typing import Optional, Dict, List, Union 9 | from transformers import PreTrainedModel, PreTrainedTokenizer 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | 14 | class GVPOTrainer: 15 | """ 16 | GVPO Trainer implementing the core algorithm: 17 | 18 | Key Formula: 19 | w_i = (R(x,y_i)
- R̄) - β(log(π_θ/π_θ') - log(π_θ/π_θ')̄) 20 | Loss = -β * Σ w_i * log π_θ(y_i|x) / (k-1) 21 | """ 22 | 23 | def __init__( 24 | self, 25 | model: PreTrainedModel, 26 | ref_model: PreTrainedModel, 27 | tokenizer: PreTrainedTokenizer, 28 | beta: float = 0.1, 29 | num_samples_per_prompt: int = 4, 30 | max_length: int = 512, 31 | learning_rate: float = 1e-6, 32 | device: str = "cuda" if torch.cuda.is_available() else "cpu" 33 | ): 34 | self.model = model.to(device) 35 | self.ref_model = ref_model.to(device) 36 | self.ref_model.eval() # Reference model stays frozen 37 | 38 | self.tokenizer = tokenizer 39 | self.beta = beta 40 | self.k = num_samples_per_prompt 41 | self.max_length = max_length 42 | self.device = device 43 | 44 | self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate) 45 | 46 | def compute_log_probs( 47 | self, 48 | model: PreTrainedModel, 49 | input_ids: torch.Tensor, 50 | attention_mask: torch.Tensor 51 | ) -> torch.Tensor: 52 | """Compute log probabilities for sequences""" 53 | with torch.no_grad() if model == self.ref_model else torch.enable_grad(): 54 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 55 | logits = outputs.logits 56 | 57 | # Shift for next-token prediction 58 | shift_logits = logits[..., :-1, :].contiguous() 59 | shift_labels = input_ids[..., 1:].contiguous() 60 | 61 | # Compute log probs 62 | log_probs = F.log_softmax(shift_logits, dim=-1) 63 | token_log_probs = torch.gather( 64 | log_probs, 65 | dim=-1, 66 | index=shift_labels.unsqueeze(-1) 67 | ).squeeze(-1) 68 | 69 | # Mask padding tokens 70 | mask = attention_mask[..., 1:].contiguous() 71 | token_log_probs = token_log_probs * mask 72 | 73 | # Sum over sequence length 74 | sequence_log_probs = token_log_probs.sum(dim=-1) 75 | 76 | return sequence_log_probs 77 | 78 | def compute_gvpo_loss( 79 | self, 80 | prompts: List[str], 81 | completions: List[List[str]], 82 | rewards: List[List[float]] 83 | ) -> Dict[str, torch.Tensor]: 84 | """ 85 | Core GVPO loss computation 86 | 87 | Args: 88 | prompts: List of prompts [batch_size] 89 | completions: List of k completions per prompt [batch_size, k] 90 | rewards: List of k rewards per prompt [batch_size, k] 91 | """ 92 | batch_size = len(prompts) 93 | total_loss = 0.0 94 | stats = {"loss": [], "advantages": [], "kl_div": []} 95 | 96 | for i in range(batch_size): 97 | prompt = prompts[i] 98 | k_completions = completions[i] 99 | k_rewards = torch.tensor(rewards[i], device=self.device) 100 | 101 | # Tokenize prompt + completions 102 | full_texts = [prompt + comp for comp in k_completions] 103 | encodings = self.tokenizer( 104 | full_texts, 105 | padding=True, 106 | truncation=True, 107 | max_length=self.max_length, 108 | return_tensors="pt" 109 | ).to(self.device) 110 | 111 | # Compute log probs from current and reference models 112 | log_probs_new = self.compute_log_probs( 113 | self.model, 114 | encodings.input_ids, 115 | encodings.attention_mask 116 | ) 117 | log_probs_old = self.compute_log_probs( 118 | self.ref_model, 119 | encodings.input_ids, 120 | encodings.attention_mask 121 | ) 122 | 123 | # Compute log ratio: log(π_θ/π_θ') 124 | log_ratio = log_probs_new - log_probs_old 125 | 126 | # GVPO advantage computation (key difference from GRPO) 127 | # w_i = (R_i - R̄) - β((log_ratio_i - log_ratio_mean)) 128 | reward_centered = k_rewards - k_rewards.mean() 129 | log_ratio_centered = log_ratio - log_ratio.mean() 130 | 131 | advantages = reward_centered - self.beta * log_ratio_centered 132 | 133 | # GVPO loss with Bessel 
correction (k-1) 134 | # Loss = -β * Σ w_i * log π_θ(y_i|x) / (k-1) 135 | loss = -self.beta * (advantages * log_probs_new).sum() / (self.k - 1) 136 | 137 | total_loss += loss 138 | 139 | # Track statistics 140 | stats["loss"].append(loss.item()) 141 | stats["advantages"].append(advantages.mean().item()) 142 | stats["kl_div"].append(log_ratio.mean().item()) 143 | 144 | avg_loss = total_loss / batch_size 145 | 146 | return { 147 | "loss": avg_loss, 148 | "mean_advantage": torch.tensor(stats["advantages"]).mean(), 149 | "mean_kl": torch.tensor(stats["kl_div"]).mean() 150 | } 151 | 152 | def train_step( 153 | self, 154 | prompts: List[str], 155 | completions: List[List[str]], 156 | rewards: List[List[float]] 157 | ) -> Dict[str, float]: 158 | """Single training step""" 159 | self.model.train() 160 | self.optimizer.zero_grad() 161 | 162 | metrics = self.compute_gvpo_loss(prompts, completions, rewards) 163 | loss = metrics["loss"] 164 | 165 | loss.backward() 166 | self.optimizer.step() 167 | 168 | return {k: v.item() if torch.is_tensor(v) else v for k, v in metrics.items()} 169 | 170 | def generate_completions( 171 | self, 172 | prompts: List[str], 173 | num_return_sequences: Optional[int] = None 174 | ) -> List[List[str]]: 175 | """Generate k completions per prompt""" 176 | if num_return_sequences is None: 177 | num_return_sequences = self.k 178 | 179 | self.model.eval() 180 | all_completions = [] 181 | 182 | with torch.no_grad(): 183 | for prompt in prompts: 184 | inputs = self.tokenizer( 185 | prompt, 186 | return_tensors="pt", 187 | truncation=True, 188 | max_length=self.max_length 189 | ).to(self.device) 190 | 191 | outputs = self.model.generate( 192 | **inputs, 193 | max_new_tokens=256, 194 | num_return_sequences=num_return_sequences, 195 | do_sample=True, 196 | temperature=0.7, 197 | top_p=0.9, 198 | pad_token_id=self.tokenizer.pad_token_id 199 | ) 200 | 201 | completions = [ 202 | self.tokenizer.decode(out[inputs.input_ids.shape[1]:], skip_special_tokens=True) 203 | for out in outputs 204 | ] 205 | all_completions.append(completions) 206 | 207 | return all_completions 208 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | 
--------------------------------------------------------------------------------
/scripts/agent.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | import json
3 | import time
4 | import zipfile
5 | from io import BytesIO
6 | 
7 | iam_client = boto3.client('iam')
8 | sts_client = boto3.client('sts')
9 | session = boto3.session.Session()
10 | region = session.region_name
11 | account_id = sts_client.get_caller_identity()["Account"]
12 | dynamodb_client = boto3.client('dynamodb')
13 | dynamodb_resource = boto3.resource('dynamodb')
14 | lambda_client = boto3.client('lambda')
15 | bedrock_agent_client = boto3.client('bedrock-agent')
16 | 
17 | 
18 | def create_dynamodb(table_name):
19 |     table = dynamodb_resource.create_table(
20 |         TableName=table_name,
21 |         KeySchema=[
22 |             {
23 |                 'AttributeName': 'booking_id',
24 |                 'KeyType': 'HASH'
25 |             }
26 |         ],
27 |         AttributeDefinitions=[
28 |             {
29 |                 'AttributeName': 'booking_id',
30 |                 'AttributeType': 'S'
31 |             }
32 |         ],
33 |         BillingMode='PAY_PER_REQUEST'  # Use on-demand capacity mode
34 |     )
35 | 
36 |     # Wait for the table to be created
37 |     print(f'Creating table {table_name}...')
38 |     table.wait_until_exists()
39 |     print(f'Table {table_name} created successfully!')
40 |     return
41 | 
42 | def get_agent_lambda_role_arn(role_name):
43 |     # Create an IAM client
44 |     iam_client = boto3.client('iam')
45 | 
46 |     try:
47 |         # Get the role information
48 |         response = iam_client.get_role(RoleName=role_name)
49 | 
50 |         # Extract the ARN from the response
51 |         role_arn = response['Role']['Arn']
52 | 
53 |         return role_arn
54 | 
55 |     except iam_client.exceptions.NoSuchEntityException:
56 |         print(f"Role '{role_name}' not found.")
57 |         return None
58 | 
59 |     except Exception as e:
60 |         print(f"An error occurred: {str(e)}")
61 |         return None
62 | 
63 | 
64 | def create_lambda(lambda_function_name, lambda_iam_role):
65 | 
66 | 
67 |     # Package up the lambda function code
68 |     s = BytesIO()
69 |     z = zipfile.ZipFile(s, 'w')
70 |     z.write("lambda_function.py")
71 |     z.close()
72 |     zip_content = s.getvalue()
73 | 
74 |     # Create Lambda Function
75 |     lambda_function = lambda_client.create_function(
76 |         FunctionName=lambda_function_name,
77 |         Runtime='python3.12',
78 |         Timeout=60,
79 |         Role=lambda_iam_role['Role']['Arn'],
80 |         Code={'ZipFile': zip_content},
81 |         Handler='lambda_function.lambda_handler'
82 |     )
83 |     return lambda_function
84 | 
85 | 
86 | def create_lambda_role(agent_name, dynamodb_table_name):
87 |     lambda_function_role = f'{agent_name}-lambda-role'
88 |     dynamodb_access_policy_name = f'{agent_name}-dynamodb-policy'
89 |     # Create IAM Role for the Lambda function
90 |     try:
91 |         assume_role_policy_document = {
92 |             "Version": "2012-10-17",
93 |             "Statement": [
94 |                 {
95 |                     "Effect": "Allow",
96 |                     "Principal": {
97 |                         "Service": "lambda.amazonaws.com"
98 |                     },
99 |                     "Action": "sts:AssumeRole"
100 |                 }
101 |             ]
102 |         }
103 | 
104 |         assume_role_policy_document_json = json.dumps(assume_role_policy_document)
105 | 
106 |         lambda_iam_role = iam_client.create_role(
107 |             RoleName=lambda_function_role,
108 |             AssumeRolePolicyDocument=assume_role_policy_document_json
109 |         )
110 | 
111 |         # Pause to make sure role is created
112 |         time.sleep(10)
113 |     except iam_client.exceptions.EntityAlreadyExistsException:
114 |         lambda_iam_role = iam_client.get_role(RoleName=lambda_function_role)
115 | 
116 |     # Attach the AWSLambdaBasicExecutionRole policy
117 |     iam_client.attach_role_policy(
118 |         RoleName=lambda_function_role,
119 | 
PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' 120 | ) 121 | 122 | # Create a policy to grant access to the DynamoDB table 123 | dynamodb_access_policy = { 124 | "Version": "2012-10-17", 125 | "Statement": [ 126 | { 127 | "Effect": "Allow", 128 | "Action": [ 129 | "dynamodb:GetItem", 130 | "dynamodb:PutItem", 131 | "dynamodb:DeleteItem" 132 | ], 133 | "Resource": "arn:aws:dynamodb:{}:{}:table/{}".format( 134 | region, account_id, dynamodb_table_name 135 | ) 136 | } 137 | ] 138 | } 139 | 140 | # Create the policy 141 | dynamodb_access_policy_json = json.dumps(dynamodb_access_policy) 142 | dynamodb_access_policy_response = iam_client.create_policy( 143 | PolicyName=dynamodb_access_policy_name, 144 | PolicyDocument=dynamodb_access_policy_json 145 | ) 146 | 147 | # Attach the policy to the Lambda function's role 148 | iam_client.attach_role_policy( 149 | RoleName=lambda_function_role, 150 | PolicyArn=dynamodb_access_policy_response['Policy']['Arn'] 151 | ) 152 | return lambda_iam_role 153 | 154 | 155 | def create_agent_role_and_policies(agent_name, agent_foundation_model, kb_id=None): 156 | agent_bedrock_allow_policy_name = f"{agent_name}-ba" 157 | agent_role_name = f'AmazonBedrockExecutionRoleForAgents_{agent_name}' 158 | # Create IAM policies for agent 159 | statements = [ 160 | { 161 | "Sid": "AmazonBedrockAgentBedrockFoundationModelPolicy", 162 | "Effect": "Allow", 163 | "Action": "bedrock:InvokeModel", 164 | "Resource": [ 165 | f"arn:aws:bedrock:{region}::foundation-model/{agent_foundation_model}" 166 | ] 167 | } 168 | ] 169 | # add Knowledge Base retrieve and retrieve and generate permissions if agent has KB attached to it 170 | if kb_id: 171 | statements.append( 172 | { 173 | "Sid": "QueryKB", 174 | "Effect": "Allow", 175 | "Action": [ 176 | "bedrock:Retrieve", 177 | "bedrock:RetrieveAndGenerate" 178 | ], 179 | "Resource": [ 180 | f"arn:aws:bedrock:{region}:{account_id}:knowledge-base/{kb_id}" 181 | ] 182 | } 183 | ) 184 | 185 | bedrock_agent_bedrock_allow_policy_statement = { 186 | "Version": "2012-10-17", 187 | "Statement": statements 188 | } 189 | 190 | bedrock_policy_json = json.dumps(bedrock_agent_bedrock_allow_policy_statement) 191 | 192 | agent_bedrock_policy = iam_client.create_policy( 193 | PolicyName=agent_bedrock_allow_policy_name, 194 | PolicyDocument=bedrock_policy_json 195 | ) 196 | 197 | # Create IAM Role for the agent and attach IAM policies 198 | assume_role_policy_document = { 199 | "Version": "2012-10-17", 200 | "Statement": [{ 201 | "Effect": "Allow", 202 | "Principal": { 203 | "Service": "bedrock.amazonaws.com" 204 | }, 205 | "Action": "sts:AssumeRole" 206 | }] 207 | } 208 | 209 | assume_role_policy_document_json = json.dumps(assume_role_policy_document) 210 | agent_role = iam_client.create_role( 211 | RoleName=agent_role_name, 212 | AssumeRolePolicyDocument=assume_role_policy_document_json 213 | ) 214 | 215 | # Pause to make sure role is created 216 | time.sleep(10) 217 | 218 | iam_client.attach_role_policy( 219 | RoleName=agent_role_name, 220 | PolicyArn=agent_bedrock_policy['Policy']['Arn'] 221 | ) 222 | return agent_role 223 | 224 | 225 | def delete_agent_roles_and_policies(agent_name): 226 | agent_bedrock_allow_policy_name = f"{agent_name}-ba" 227 | agent_role_name = f'AmazonBedrockExecutionRoleForAgents_{agent_name}' 228 | dynamodb_access_policy_name = f'{agent_name}-dynamodb-policy' 229 | lambda_function_role = f'{agent_name}-lambda-role' 230 | 231 | for policy in [agent_bedrock_allow_policy_name]: 232 | try: 233 | 
iam_client.detach_role_policy( 234 | RoleName=agent_role_name, 235 | PolicyArn=f'arn:aws:iam::{account_id}:policy/{policy}' 236 | ) 237 | except Exception as e: 238 | print(f"Could not detach {policy} from {agent_role_name}") 239 | print(e) 240 | 241 | for policy in [dynamodb_access_policy_name]: 242 | try: 243 | iam_client.detach_role_policy( 244 | RoleName=lambda_function_role, 245 | PolicyArn=f'arn:aws:iam::{account_id}:policy/{policy}' 246 | ) 247 | except Exception as e: 248 | print(f"Could not detach {policy} from {lambda_function_role}") 249 | print(e) 250 | 251 | try: 252 | iam_client.detach_role_policy( 253 | RoleName=lambda_function_role, 254 | PolicyArn='arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' 255 | ) 256 | except Exception as e: 257 | print(f"Could not detach AWSLambdaBasicExecutionRole from {lambda_function_role}") 258 | print(e) 259 | 260 | for role_name in [agent_role_name, lambda_function_role]: 261 | try: 262 | iam_client.delete_role( 263 | RoleName=role_name 264 | ) 265 | except Exception as e: 266 | print(f"Could not delete role {role_name}") 267 | print(e) 268 | 269 | for policy in [agent_bedrock_allow_policy_name, dynamodb_access_policy_name]: 270 | try: 271 | iam_client.delete_policy( 272 | PolicyArn=f'arn:aws:iam::{account_id}:policy/{policy}' 273 | ) 274 | except Exception as e: 275 | print(f"Could not delete policy {policy}") 276 | print(e) 277 | 278 | 279 | def clean_up_resources( 280 | table_name, lambda_function, lambda_function_name, agent_action_group_response, agent_functions, 281 | agent_id, kb_id, alias_id 282 | ): 283 | action_group_id = agent_action_group_response['agentActionGroup']['actionGroupId'] 284 | action_group_name = agent_action_group_response['agentActionGroup']['actionGroupName'] 285 | # Delete Agent Action Group, Agent Alias, and Agent 286 | try: 287 | bedrock_agent_client.update_agent_action_group( 288 | agentId=agent_id, 289 | agentVersion='DRAFT', 290 | actionGroupId= action_group_id, 291 | actionGroupName=action_group_name, 292 | actionGroupExecutor={ 293 | 'lambda': lambda_function['FunctionArn'] 294 | }, 295 | functionSchema={ 296 | 'functions': agent_functions 297 | }, 298 | actionGroupState='DISABLED', 299 | ) 300 | bedrock_agent_client.disassociate_agent_knowledge_base( 301 | agentId=agent_id, 302 | agentVersion='DRAFT', 303 | knowledgeBaseId=kb_id 304 | ) 305 | bedrock_agent_client.delete_agent_action_group( 306 | agentId=agent_id, 307 | agentVersion='DRAFT', 308 | actionGroupId=action_group_id 309 | ) 310 | bedrock_agent_client.delete_agent_alias( 311 | agentAliasId=alias_id, 312 | agentId=agent_id 313 | ) 314 | bedrock_agent_client.delete_agent(agentId=agent_id) 315 | print(f"Agent {agent_id}, Agent Alias {alias_id}, and Action Group have been deleted.") 316 | except Exception as e: 317 | print(f"Error deleting Agent resources: {e}") 318 | 319 | # Delete Lambda function 320 | try: 321 | lambda_client.delete_function(FunctionName=lambda_function_name) 322 | print(f"Lambda function {lambda_function_name} has been deleted.") 323 | except Exception as e: 324 | print(f"Error deleting Lambda function {lambda_function_name}: {e}") 325 | 326 | # Delete DynamoDB table 327 | try: 328 | dynamodb_client.delete_table(TableName=table_name) 329 | print(f"Table {table_name} is being deleted...") 330 | waiter = dynamodb_client.get_waiter('table_not_exists') 331 | waiter.wait(TableName=table_name) 332 | print(f"Table {table_name} has been deleted.") 333 | except Exception as e: 334 | print(f"Error deleting table 
{table_name}: {e}")
--------------------------------------------------------------------------------
/Reasoning-through-incentives/README.md:
--------------------------------------------------------------------------------
1 | # GVPO vs. GRPO: Novelties, Advantages, and Mathematical Differences
2 | 
3 | ## Executive Summary
4 | 
5 | **GVPO (Group Variance Policy Optimization)** addresses the training instability issues of GRPO while providing stronger theoretical guarantees. The key innovation is incorporating the **analytical solution to KL-constrained reward maximization directly into gradient weights** through a clever **zero-sum weight constraint** that eliminates the intractable partition function.
6 | 
7 | Original paper: https://arxiv.org/pdf/2504.19599
8 | 
9 | ---
10 | 
11 | ## 1. Core Novelties of GVPO
12 | 
13 | ### **Novel 1: Zero-Sum Weight Constraint Eliminates Partition Function**
14 | 
15 | **The Problem:**
16 | The optimal policy for KL-constrained reward maximization has a closed-form solution:
17 | 
18 | $$\pi^*(y|x) = \frac{1}{Z(x)}\pi_{\theta'}(y|x)e^{R(x,y)/\beta}$$
19 | 
20 | where $Z(x) = \sum_y \pi_{\theta'}(y|x)e^{R(x,y)/\beta}$ is computationally intractable (requires summing over all possible responses).
21 | 
22 | **GVPO's Solution:**
23 | By designing weights such that $\sum_{i=1}^k w_i = 0$, the partition function $\beta \log Z(x)$ becomes **invariant across responses and cancels out** in gradient computations:
24 | 
25 | $$\nabla_\theta L(\theta) = -\sum_{x,\{y_i\}} \sum_{i=1}^k w_i \nabla_\theta \log \frac{\pi_\theta(y_i|x)}{\pi_{\theta'}(y_i|x)} = -\sum_{x,\{y_i\}} \sum_{i=1}^k w_i \nabla_\theta \frac{R_\theta(x,y_i)}{\beta}$$
26 | 
27 | Since $\sum w_i = 0$, the $\beta \log Z(x)$ term disappears, making the method computationally tractable.
28 | 
29 | ### **Novel 2: Gradient Weights Based on Central Distance Differences**
30 | 
31 | **GVPO's Weight Design:**
32 | 
33 | $$w_i = (R(x, y_i) - \bar{R}(x)) - \beta\left(\log\frac{\pi_\theta(y_i|x)}{\pi_{\theta'}(y_i|x)} - \overline{\log\frac{\pi_\theta}{\pi_{\theta'}}}\right)$$
34 | 
35 | where the bar notation denotes group average: $\bar{R}(x) = \frac{1}{k}\sum_{i=1}^k R(x, y_i)$.
36 | 
37 | **Physical Interpretation:**
38 | The weight is the **difference between actual reward central distance and implicit reward central distance**.
39 | 
40 | ### **Novel 3: Three Equivalent Loss Interpretations**
41 | 
42 | The paper elegantly shows GVPO's loss has three mathematically equivalent forms:
43 | 
44 | #### **(a) Negative Log-Likelihood View:**
45 | 
46 | $$\mathcal{L}_{\text{GVPO}}(\theta) = -\beta\sum_{x,\{y_i\}} \sum_{i=1}^k \left[(R(x,y_i) - \bar{R}) - \beta\left(\log\frac{\pi_\theta(y_i|x)}{\pi_{\theta'}(y_i|x)} - \overline{\log\frac{\pi_\theta}{\pi_{\theta'}}}\right)\right] \log \pi_\theta(y_i|x)$$
47 | 
48 | #### **(b) Mean Squared Error View:**
49 | 
50 | $$\nabla_\theta \mathcal{L}_{\text{GVPO}} = \frac{1}{2}\nabla_\theta \sum_{x,\{y_i\}} \sum_{i=1}^k \left[(R_\theta(x,y_i) - \bar{R}_\theta) - (R(x,y_i) - \bar{R})\right]^2$$
51 | 
52 | **Key Insight:** Minimizing GVPO loss = minimizing **MSE between implicit and actual reward central distances**.
53 | 
54 | #### **(c) Reinforcement Learning View (β=1):**
55 | 
56 | $$\nabla_\theta \hat{\mathcal{L}}_{\text{GVPO}} = -2\mathbb{E}_{x,y}\left[(R(x,y) - \mathbb{E}_y R) \log \pi_\theta(y|x) + \text{Cov}(\log \pi_\theta, \log \pi_{\theta'}) - 0.5\text{Var}(\log \pi_\theta)\right]$$
57 | 
58 | Three components:
59 | 1. **Group-relative reward term**: Advantage maximization
60 | 2.
**Covariance term**: Regularization preventing deviation from reference policy
61 | 3. **Variance term**: Entropy-like exploration encouragement
62 | 
63 | ---
64 | 
65 | ## 2. Mathematical Comparison: GVPO vs GRPO
66 | 
67 | ### **GRPO Loss (Equation 2):**
68 | 
69 | $$\mathcal{L}_{\text{GRPO}}(\theta) = -\sum_{x,y_1,\ldots,y_k} \sum_{i=1}^k \frac{R(x,y_i) - \text{Mean}(\{R(x,y_i)\})}{\text{Std}(\{R(x,y_i)\})} \log \pi_\theta(y_i|x)$$
70 | 
71 | **Key Differences:**
72 | 
73 | | Aspect | GRPO | GVPO |
74 | |--------|------|------|
75 | | **Weight Formula** | $w_i = \frac{R(x,y_i) - \bar{R}}{\sigma_R}$ (standardized reward) | $w_i = (R(x,y_i) - \bar{R}) - \beta(\log\frac{\pi_\theta}{\pi_{\theta'}} - \overline{\log\frac{\pi_\theta}{\pi_{\theta'}}})$ |
76 | | **Normalization** | Divides by standard deviation $\sigma_R$ | No std normalization (only centering) |
77 | | **Policy Dependency** | Weights independent of current policy | Weights depend on $\pi_\theta/\pi_{\theta'}$ ratio |
78 | | **KL Constraint** | Applied externally (hyperparameter tuning) | **Built into gradient weights analytically** |
79 | | **Zero-Sum Property** | Yes (due to centering) | Yes (by design) |
80 | 
81 | ### **Critical Mathematical Insight:**
82 | 
83 | GRPO's standardization **conflates prompt-level difficulty with reward signals** (cited in paper [17]). For example:
84 | - Hard prompt with rewards [8, 9, 10] → all responses get similar standardized scores
85 | - Easy prompt with rewards [1, 2, 9] → large standardized score differences
86 | 
87 | GVPO **removes std normalization** but adds the $\beta(\log\pi_\theta/\pi_{\theta'})$ term to directly encode the optimal policy structure.
88 | 
89 | ---
90 | 
91 | ## 3. Theoretical Advantages of GVPO
92 | 
93 | ### **Advantage 1: Unique Optimal Solution**
94 | 
95 | **GVPO Guarantee:**
96 | 
97 | $$\text{argmin}_\theta \hat{\mathcal{L}}_{\text{GVPO}}(\theta) = \pi^*(y|x) = \frac{1}{Z(x)}\pi_{\theta'}(y|x)e^{R(x,y)/\beta}$$
98 | 
99 | **Uniqueness** is proven by showing:
100 | 1. When $\pi_\theta = \pi^*$, the loss equals 0 (minimum achieved)
101 | 2. Any other policy yields loss > 0 (contradiction proof in Appendix B.1)
102 | 
103 | **Why This Matters:**
104 | - **DPO fails this**: Due to Bradley-Terry model limitations [3, 11], DPO may converge to suboptimal policies
105 | - **GRPO lacks this**: No theoretical guarantee of convergence to KL-constrained optimum
106 | 
107 | ### **Advantage 2: Flexible Sampling Distributions**
108 | 
109 | **GVPO's Condition:**
110 | Theorem 3.1 holds for **any sampling distribution $\pi_s$** satisfying:
111 | $$\forall x, \{y|\pi_{\theta'}(y|x) > 0\} \subseteq \{y|\pi_s(y|x) > 0\}$$
112 | 
113 | **Translation:** As long as $\pi_s$ covers all responses that the reference policy could generate, GVPO maintains theoretical guarantees.
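To make this concrete, here is a minimal PyTorch sketch (our illustration, not the paper's code; the rewards and log-probabilities are made-up stand-ins for per-response quantities). Note that only log-probability **differences** enter the weights; no $\pi_\theta/\pi_s$ ratio appears anywhere:

```python
import torch

beta = 0.1
# Hypothetical per-response quantities for one prompt's group of k = 4 responses
rewards  = torch.tensor([1.0, 0.0, 0.5, 1.0])           # R(x, y_i), e.g. verifier scores
logp_cur = torch.tensor([-42.0, -55.0, -47.0, -40.0])   # log pi_theta(y_i|x)
logp_ref = torch.tensor([-41.0, -54.0, -48.0, -43.0])   # log pi_theta'(y_i|x)

# w_i = (R_i - mean R) - beta * (log-ratio_i - mean log-ratio)
log_ratio = logp_cur - logp_ref
w = (rewards - rewards.mean()) - beta * (log_ratio - log_ratio.mean())

print(w.sum())  # ~0 by construction: the zero-sum property that cancels beta*log Z(x)
```

Because both terms are centered, $\sum_i w_i = 0$ holds no matter which sampler produced the responses, which is exactly what lets $\beta \log Z(x)$ drop out.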
114 | 
115 | **Comparison with GRPO/PPO:**
116 | 
117 | | Method | Sampling Requirement | Problem |
118 | |--------|---------------------|---------|
119 | | **PPO** | On-policy ($\pi_s = \pi_\theta$) | Low sample efficiency, requires fresh trajectories |
120 | | **GRPO** | Uses importance sampling $\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}}$ | Gradient explosion when policies diverge; requires clipping (introduces bias) |
121 | | **GVPO** | Any $\pi_s$ satisfying mild condition | **No importance sampling, no explosion risk** |
122 | 
123 | **Mathematical Detail:**
124 | Policy gradient methods require:
125 | 
126 | $$\nabla_\theta[\mathbb{E}_{x,y\sim\pi_\theta}[R(x,y)] - \text{DKL}[\pi_\theta||\pi_{\theta_{\text{old}}}]] = \mathbb{E}_{x,y\sim\pi_\theta}\left[\left(R - \log\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}} - 1\right)\nabla_\theta \log \pi_\theta\right]$$
127 | 
128 | Off-policy estimation uses importance sampling (Equation 16):
129 | 
130 | $$\mathbb{E}_{x,y\sim\pi_{\theta_{\text{old}}}}\left[\frac{\pi_\theta(y|x)}{\pi_{\theta_{\text{old}}}(y|x)} \left(R - \log\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}} - 1\right)\nabla_\theta \log \pi_\theta\right]$$
131 | 
132 | The ratio $\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}}$ can explode → gradient clipping needed.
133 | 
134 | **GVPO's Advantage:**
135 | By using the zero-sum property and central distances, GVPO's gradient becomes:
136 | 
137 | $$\mathbb{E}_{x,y\sim\pi_s}\left[\left(R - \log\frac{\pi_\theta}{\pi_{\theta'}} - \mathbb{E}_{y\sim\pi_s}\left(R - \log\frac{\pi_\theta}{\pi_{\theta'}}\right)\right)\nabla_\theta \log \pi_\theta\right]$$
138 | 
139 | **No importance sampling ratio** appears in the gradient!
140 | 
141 | ### **Advantage 3: Unbiased and Consistent Estimator**
142 | 
143 | The empirical loss with finite samples is:
144 | 
145 | $$\frac{1}{|D|}\sum_{(x,\{y_i\})\in D} \frac{1}{k-1} \sum_{i=1}^k \left[(R_\theta(x,y_i) - \bar{R}_\theta) - (R(x,y_i) - \bar{R})\right]^2$$
146 | 
147 | **Note the $\frac{1}{k-1}$ factor** (not $\frac{1}{k}$) — this is the **Bessel correction** for unbiased variance estimation.
148 | 
149 | **Why This Matters:**
150 | - With small $k$ (few samples per prompt), bias becomes significant
151 | - Corollary 3.5 extends this to **variable $k(x)$ per prompt**, enabling mixed-source datasets
152 | 
153 | ---
154 | 
155 | ## 4. Algorithm Comparison
156 | 
157 | ### **Algorithm 1 (GVPO) vs GRPO:**
158 | 
159 | ```
160 | GVPO:
161 | 1. Sample k responses {yi} ~ πs(·|x)
162 | 2. Compute weights: wi = (R(x,yi) - R̄) - β(log(πθ/πθ') - log(πθ/πθ')̄)
163 | 3. Update: minimize -Σ wi log πθ(yi|x)
164 | 
165 | GRPO:
166 | 1. Sample k responses {yi} ~ πθ_old(·|x)
167 | 2. Compute weights: wi = (R(x,yi) - R̄) / σR
168 | 3. Update: minimize -Σ wi log πθ(yi|x)
169 | 4. Apply gradient clipping + KL penalty
170 | ```
171 | 
172 | **Key Implementation Difference (Listing 1):**
173 | 
174 | GVPO only changes GRPO's loss computation by:
175 | ```python
176 | # GRPO:
177 | advs = (R - R.mean()) / R.std()  # Standardization
178 | loss = -scores * advs
179 | 
180 | # GVPO:
181 | advs = (R - R.mean()) - beta * ((scores_new - scores_new.mean())
182 |                                 - (scores_old - scores_old.mean()))
183 | loss = -beta * scores * advs / (k-1)  # Note: k-1 for unbiased estimator
184 | ```
185 | 
186 | ---
187 | 
188 | ## 5.
Empirical Performance 189 | 190 | | Model | AIME2024 | AMC | MATH500 | Minerva | OlympiadBench | 191 | |-------|----------|-----|---------|---------|---------------| 192 | | Base (Qwen2.5-Math-7B) | 14.68 | 38.55 | 64.00 | 27.20 | 30.66 | 193 | | +GRPO | 14.79 | 55.42 | **80.00** | 41.17 | 42.07 | 194 | | +Dr.GRPO | 16.56 | 48.19 | 81.20 | 44.48 | 43.40 | 195 | | +**GVPO** | **20.72** | **62.65** | **83.80** | **45.95** | **46.96** | 196 | 197 | **Observations:** 198 | - GVPO achieves **best performance across all 5 benchmarks** 199 | - **40% relative improvement** on AIME2024 over GRPO (14.79 → 20.72) 200 | - Particularly strong on complex reasoning tasks (AIME, OlympiadBench) 201 | 202 | ### **Ablation Study Insights:** 203 | 204 | **Figure 2 (β sensitivity):** 205 | - GVPO shows **little performance fluctuation** across β ∈ [0.01, 0.5] 206 | - Suggests **robustness to hyperparameter tuning** (unlike GRPO's high sensitivity) 207 | 208 | **Figure 3 (Scaling with k):** 209 | - GVPO **consistently outperforms GRPO** for all k ∈ [2, 32] 210 | - **Superior scalability**: GVPO on 1.5B model with k=32 matches 7B model performance 211 | - **Inference cost reduction**: Can use smaller models with more samples 212 | 213 | **Figure 4 (Off-policy sampling πs):** 214 | - Tests mixing historical responses with current policy samples 215 | - GVPO maintains **robust performance** with ratios from 0:8 to 4:4 (historical:current) 216 | - Validates Corollary 3.2's theoretical guarantee 217 | 218 | --- 219 | 220 | ## 6. Limitations and Future Work 221 | 222 | ### **Acknowledged Limitations:** 223 | 1. **Computational cost**: Still requires sampling k responses per prompt 224 | 2. **Reward model quality**: Performance depends on accurate R(x,y) 225 | 3. **Hyperparameter β**: Though robust, still requires selection 226 | 227 | ### **Unexplored Connections:** 228 | - Integration with exploration strategies from classical RL 229 | - Extension to continuous action spaces 230 | - Multi-modal reward signals 231 | 232 | --- 233 | 234 | ## 7. Summary: Why GVPO is Better 235 | 236 | | Criterion | GRPO | GVPO | 237 | |-----------|------|------| 238 | | **Training Stability** | ❌ Documented instability [34, 16] | ✅ Implicit regularization via Cov/Var terms | 239 | | **Hyperparameter Sensitivity** | ❌ High (clip threshold, KL coeff) | ✅ Robust to β variations | 240 | | **Theoretical Guarantee** | ❌ No convergence to optimal policy | ✅ Unique optimal = KL-constrained optimum | 241 | | **Sampling Flexibility** | ⚠️ Uses importance sampling | ✅ Any πs (no IS needed) | 242 | | **Normalization Bias** | ❌ Std normalization conflates difficulty | ✅ Only centering (no std division) | 243 | | **Gradient Explosion** | ❌ Requires clipping | ✅ No IS ratio → inherently stable | 244 | | **Performance** | Baseline | **Best across all benchmarks** | 245 | 246 | --- 247 | 248 | ## **Bottom Line** 249 | 250 | GVPO's core innovation is **operationalizing the closed-form optimal policy** through a mathematically elegant **zero-sum weight design** that: 251 | 1. Eliminates the intractable partition function 252 | 2. Embeds KL constraints directly into gradients 253 | 3. Enables off-policy training without importance sampling 254 | 4. Guarantees convergence to the unique optimal policy 255 | 256 | This makes GVPO a **theoretically principled AND empirically superior** alternative to GRPO for LLM post-training. 
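As a numerical sanity check on the equivalence of views (a) and (b) above, the following self-contained PyTorch sketch (ours, not the paper's; it uses a toy categorical policy, invented rewards, and the common detached-weight implementation of a weighted NLL) confirms that both forms yield the same gradient:

```python
import torch

torch.manual_seed(0)
beta = 0.1

theta = torch.randn(6, requires_grad=True)            # toy policy logits over 6 "responses"
logp_ref = torch.log_softmax(torch.randn(6), dim=-1)  # frozen reference policy

idx = torch.tensor([0, 2, 3, 5])                      # the k = 4 sampled responses
R = torch.tensor([1.0, 0.0, 0.5, 1.0])                # their (made-up) rewards

def centered(v):
    return v - v.mean()

# View (a): weighted negative log-likelihood with detached zero-sum weights
logp = torch.log_softmax(theta, dim=-1)[idx]
w = centered(R) - beta * centered(logp - logp_ref[idx])
loss_nll = -beta * (w.detach() * logp).sum()
g_nll, = torch.autograd.grad(loss_nll, theta)

# View (b): MSE between centered implicit and centered actual rewards
# (the beta*log Z(x) part of the implicit reward vanishes under centering)
logp = torch.log_softmax(theta, dim=-1)[idx]
implicit = beta * (logp - logp_ref[idx])
loss_mse = 0.5 * ((centered(implicit) - centered(R)) ** 2).sum()
g_mse, = torch.autograd.grad(loss_mse, theta)

print(torch.allclose(g_nll, g_mse, atol=1e-6))        # True: the two views share a gradient
```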
-------------------------------------------------------------------------------- /Reasoning-through-incentives/gvpo.md: -------------------------------------------------------------------------------- 1 | # GVPO Paper Analysis: Novelties, Advantages, and Mathematical Differences 2 | 3 | ## Executive Summary 4 | 5 | **GVPO (Group Variance Policy Optimization)** addresses the training instability issues of GRPO while providing stronger theoretical guarantees. The key innovation is incorporating the **analytical solution to KL-constrained reward maximization directly into gradient weights** through a clever **zero-sum weight constraint** that eliminates the intractable partition function. 6 | 7 | --- 8 | 9 | ## 1. Core Novelties of GVPO 10 | 11 | ### **Novel 1: Zero-Sum Weight Constraint Eliminates Partition Function** 12 | 13 | **The Problem:** 14 | The optimal policy for KL-constrained reward maximization has a closed-form solution: 15 | 16 | $$\pi^*(y|x) = \frac{1}{Z(x)}\pi_{\theta'}(y|x)e^{R(x,y)/\beta}$$ 17 | 18 | where $Z(x) = \sum_y \pi_{\theta'}(y|x)e^{R(x,y)/\beta}$ is computationally intractable (requires summing over all possible responses). 19 | 20 | **GVPO's Solution:** 21 | By designing weights such that $\sum_{i=1}^k w_i = 0$, the partition function $\beta \log Z(x)$ becomes **invariant across responses and cancels out** in gradient computations: 22 | 23 | $$\nabla_\theta L(\theta) = -\sum_{x,\{y_i\}} \sum_{i=1}^k w_i \nabla_\theta \log \frac{\pi_\theta(y_i|x)}{\pi_{\theta'}(y_i|x)} = -\sum_{x,\{y_i\}} \sum_{i=1}^k w_i \nabla_\theta \frac{R_\theta(x,y_i)}{\beta}$$ 24 | 25 | Since $\sum w_i = 0$, the $\beta \log Z(x)$ term disappears, making the method computationally tractable. 26 | 27 | ### **Novel 2: Gradient Weights Based on Central Distance Differences** 28 | 29 | **GVPO's Weight Design:** 30 | 31 | $$w_i = (R(x, y_i) - \bar{R}(x)) - \beta\left(\log\frac{\pi_\theta(y_i|x)}{\pi_{\theta'}(y_i|x)} - \overline{\log\frac{\pi_\theta}{\pi_{\theta'}}}\right)$$ 32 | 33 | where the bar notation denotes group average: $\bar{R}(x) = \frac{1}{k}\sum_{i=1}^k R(x, y_i)$. 34 | 35 | **Physical Interpretation:** 36 | The weight is the **difference between actual reward central distance and implicit reward central distance**. 37 | 38 | ### **Novel 3: Three Equivalent Loss Interpretations** 39 | 40 | The paper elegantly shows GVPO's loss has three mathematically equivalent forms: 41 | 42 | #### **(a) Negative Log-Likelihood View (Equation 9):** 43 | 44 | $$\mathcal{L}_{\text{GVPO}}(\theta) = -\beta\sum_{x,\{y_i\}} \sum_{i=1}^k \left[(R(x,y_i) - \bar{R}) - \beta\left(\log\frac{\pi_\theta(y_i|x)}{\pi_{\theta'}(y_i|x)} - \overline{\log\frac{\pi_\theta}{\pi_{\theta'}}}\right)\right] \log \pi_\theta(y_i|x)$$ 45 | 46 | #### **(b) Mean Squared Error View (Middle panel, Figure 1):** 47 | 48 | $$\nabla_\theta \mathcal{L}_{\text{GVPO}} = \frac{1}{2}\nabla_\theta \sum_{x,\{y_i\}} \sum_{i=1}^k \left[(R_\theta(x,y_i) - \bar{R}_\theta) - (R(x,y_i) - \bar{R})\right]^2$$ 49 | 50 | **Key Insight:** Minimizing GVPO loss = minimizing **MSE between implicit and actual reward central distances**. 51 | 52 | #### **(c) Reinforcement Learning View (Equation 14, β=1):** 53 | 54 | $$\nabla_\theta \hat{\mathcal{L}}_{\text{GVPO}} = -2\mathbb{E}_{x,y}\left[(R(x,y) - \mathbb{E}_y R) \log \pi_\theta(y|x) + \text{Cov}(\log \pi_\theta, \log \pi_{\theta'}) - 0.5\text{Var}(\log \pi_\theta)\right]$$ 55 | 56 | Three components: 57 | 1. **Group-relative reward term**: Advantage maximization 58 | 2. 
**Covariance term**: Regularization preventing deviation from reference policy 59 | 3. **Variance term**: Entropy-like exploration encouragement 60 | 61 | --- 62 | 63 | ## 2. Mathematical Comparison: GVPO vs GRPO 64 | 65 | ### **GRPO Loss (Equation 2):** 66 | 67 | $$\mathcal{L}_{\text{GRPO}}(\theta) = -\sum_{x,y_1,\ldots,y_k} \sum_{i=1}^k \frac{R(x,y_i) - \text{Mean}(\{R(x,y_i)\})}{\text{Std}(\{R(x,y_i)\})} \log \pi_\theta(y_i|x)$$ 68 | 69 | **Key Differences:** 70 | 71 | | Aspect | GRPO | GVPO | 72 | |--------|------|------| 73 | | **Weight Formula** | $w_i = \frac{R(x,y_i) - \bar{R}}{\sigma_R}$ (standardized reward) | $w_i = (R(x,y_i) - \bar{R}) - \beta(\log\frac{\pi_\theta}{\pi_{\theta'}} - \overline{\log\frac{\pi_\theta}{\pi_{\theta'}}})$ | 74 | | **Normalization** | Divides by standard deviation $\sigma_R$ | No std normalization (only centering) | 75 | | **Policy Dependency** | Weights independent of current policy | Weights depend on $\pi_\theta/\pi_{\theta'}$ ratio | 76 | | **KL Constraint** | Applied externally (hyperparameter tuning) | **Built into gradient weights analytically** | 77 | | **Zero-Sum Property** | Yes (due to centering) | Yes (by design) | 78 | 79 | ### **Critical Mathematical Insight:** 80 | 81 | GRPO's standardization **conflates prompt-level difficulty with reward signals** (cited in paper [17]). For example: 82 | - Hard prompt with rewards [8, 9, 10] → all responses get similar standardized scores 83 | - Easy prompt with rewards [1, 2, 9] → large standardized score differences 84 | 85 | GVPO **removes std normalization** but adds the $\beta(\log\pi_\theta/\pi_{\theta'})$ term to directly encode the optimal policy structure. 86 | 87 | --- 88 | 89 | ## 3. Theoretical Advantages of GVPO 90 | 91 | ### **Advantage 1: Unique Optimal Solution (Theorem 3.1)** 92 | 93 | **GVPO Guarantee:** 94 | $$\text{argmin}_\theta \hat{\mathcal{L}}_{\text{GVPO}}(\theta) = \pi^*(y|x) = \frac{1}{Z(x)}\pi_{\theta'}(y|x)e^{R(x,y)/\beta}$$ 95 | 96 | **Uniqueness** is proven by showing: 97 | 1. When $\pi_\theta = \pi^*$, the loss equals 0 (minimum achieved) 98 | 2. Any other policy yields loss > 0 (contradiction proof in Appendix B.1) 99 | 100 | **Why This Matters:** 101 | - **DPO fails this**: Due to Bradley-Terry model limitations [3, 11], DPO may converge to suboptimal policies 102 | - **GRPO lacks this**: No theoretical guarantee of convergence to KL-constrained optimum 103 | 104 | ### **Advantage 2: Flexible Sampling Distributions (Corollary 3.2)** 105 | 106 | **GVPO's Condition:** 107 | Theorem 3.1 holds for **any sampling distribution $\pi_s$** satisfying: 108 | $$\forall x, \{y|\pi_{\theta'}(y|x) > 0\} \subseteq \{y|\pi_s(y|x) > 0\}$$ 109 | 110 | **Translation:** As long as $\pi_s$ covers all responses that the reference policy could generate, GVPO maintains theoretical guarantees. 
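A small operational sketch of what this buys (ours, not the paper's code; every number is hypothetical): cached responses from an earlier checkpoint and freshly sampled ones can share a single group, and the weight formula treats them identically, with no importance ratio attached to the stale members:

```python
import torch

beta = 0.1
# One prompt's group of k = 4. The first two responses are replayed from an
# earlier checkpoint, the last two were just sampled. Corollary 3.2 only asks
# that the combined sampler cover the reference policy's support.
rewards  = torch.tensor([0.0, 1.0, 1.0, 0.5])
logp_cur = torch.tensor([-51.0, -44.0, -40.0, -46.0])  # log pi_theta  (current)
logp_ref = torch.tensor([-50.0, -45.0, -42.0, -45.0])  # log pi_theta' (reference)

# Identical zero-sum weights for every group member; nothing references the
# sampler's probabilities, so stale samples cannot produce exploding ratios.
log_ratio = logp_cur - logp_ref
w = (rewards - rewards.mean()) - beta * (log_ratio - log_ratio.mean())
loss = -beta * (w.detach() * logp_cur).sum()  # NLL form, weights held fixed
```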
111 | 112 | **Comparison with GRPO/PPO:** 113 | 114 | | Method | Sampling Requirement | Problem | 115 | |--------|---------------------|---------| 116 | | **PPO** | On-policy ($\pi_s = \pi_\theta$) | Low sample efficiency, requires fresh trajectories | 117 | | **GRPO** | Uses importance sampling $\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}}$ | Gradient explosion when policies diverge; requires clipping (introduces bias) | 118 | | **GVPO** | Any $\pi_s$ satisfying mild condition | **No importance sampling, no explosion risk** | 119 | 120 | **Mathematical Detail:** 121 | Policy gradient methods require: 122 | 123 | $$\nabla_\theta[\mathbb{E}_{x,y\sim\pi_\theta}[R(x,y)] - \text{DKL}[\pi_\theta||\pi_{\theta_{\text{old}}}]] = \mathbb{E}_{x,y\sim\pi_\theta}\left[\left(R - \log\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}} - 1\right)\nabla_\theta \log \pi_\theta\right]$$ 124 | 125 | Off-policy estimation uses importance sampling (Equation 16): 126 | 127 | $$\mathbb{E}_{x,y\sim\pi_{\theta_{\text{old}}}}\left[\frac{\pi_\theta(y|x)}{\pi_{\theta_{\text{old}}}(y|x)} \left(R - \log\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}} - 1\right)\nabla_\theta \log \pi_\theta\right]$$ 128 | 129 | The ratio $\frac{\pi_\theta}{\pi_{\theta_{\text{old}}}}$ can explode → gradient clipping needed. 130 | 131 | **GVPO's Advantage:** 132 | By using the zero-sum property and central distances, GVPO's gradient becomes: 133 | 134 | $$\mathbb{E}_{x,y\sim\pi_s}\left[\left(R - \log\frac{\pi_\theta}{\pi_{\theta'}} - \mathbb{E}_{y\sim\pi_s}\left(R - \log\frac{\pi_\theta}{\pi_{\theta'}}\right)\right)\nabla_\theta \log \pi_\theta\right]$$ 135 | 136 | **No importance sampling ratio** appears in the gradient! 137 | 138 | ### **Advantage 3: Unbiased and Consistent Estimator (Theorem 3.4)** 139 | 140 | The empirical loss with finite samples is: 141 | 142 | $$\frac{1}{|D|}\sum_{(x,\{y_i\})\in D} \frac{1}{k-1} \sum_{i=1}^k \left[(R_\theta(x,y_i) - \bar{R}_\theta) - (R(x,y_i) - \bar{R})\right]^2$$ 143 | 144 | **Note the $\frac{1}{k-1}$ factor** (not $\frac{1}{k}$) — this is the **Bessel correction** for unbiased variance estimation. 145 | 146 | **Why This Matters:** 147 | - With small $k$ (few samples per prompt), bias becomes significant 148 | - Corollary 3.5 extends this to **variable $k(x)$ per prompt**, enabling mixed-source datasets 149 | 150 | --- 151 | 152 | ## 4. Algorithm Comparison 153 | 154 | ### **Algorithm 1 (GVPO) vs GRPO:** 155 | 156 | ``` 157 | GVPO: 158 | 1. Sample k responses {yi} ~ πs(·|x) 159 | 2. Compute weights: wi = (R(x,yi) - R̄) - β(log(πθ/πθ') - log(πθ/πθ')̄) 160 | 3. Update: minimize -Σ wi log πθ(yi|x) 161 | 162 | GRPO: 163 | 1. Sample k responses {yi} ~ πθ_old(·|x) 164 | 2. Compute weights: wi = (R(x,yi) - R̄) / σR 165 | 3. Update: minimize -Σ wi log πθ(yi|x) 166 | 4. Apply gradient clipping + KL penalty 167 | ``` 168 | 169 | **Key Implementation Difference (Listing 1):** 170 | 171 | GVPO only changes GRPO's loss computation by: 172 | ```python 173 | # GRPO: 174 | advs = (R - R.mean()) / R.std() # Standardization 175 | loss = -scores * advs 176 | 177 | # GVPO: 178 | advs = (R - R.mean()) - beta * ((scores_new - scores_new.mean()) 179 | - (scores_old - scores_old.mean())) 180 | loss = -beta * scores * advs / (k-1) # Note: k-1 for unbiased estimator 181 | ``` 182 | 183 | --- 184 | 185 | ## 5. 
Empirical Performance (Table 1) 186 | 187 | | Model | AIME2024 | AMC | MATH500 | Minerva | OlympiadBench | 188 | |-------|----------|-----|---------|---------|---------------| 189 | | Base (Qwen2.5-Math-7B) | 14.68 | 38.55 | 64.00 | 27.20 | 30.66 | 190 | | +GRPO | 14.79 | 55.42 | **80.00** | 41.17 | 42.07 | 191 | | +Dr.GRPO | 16.56 | 48.19 | 81.20 | 44.48 | 43.40 | 192 | | +**GVPO** | **20.72** | **62.65** | **83.80** | **45.95** | **46.96** | 193 | 194 | **Observations:** 195 | - GVPO achieves **best performance across all 5 benchmarks** 196 | - **40% relative improvement** on AIME2024 over GRPO (14.79 → 20.72) 197 | - Particularly strong on complex reasoning tasks (AIME, OlympiadBench) 198 | 199 | ### **Ablation Study Insights:** 200 | 201 | **Figure 2 (β sensitivity):** 202 | - GVPO shows **little performance fluctuation** across β ∈ [0.01, 0.5] 203 | - Suggests **robustness to hyperparameter tuning** (unlike GRPO's high sensitivity) 204 | 205 | **Figure 3 (Scaling with k):** 206 | - GVPO **consistently outperforms GRPO** for all k ∈ [2, 32] 207 | - **Superior scalability**: GVPO on 1.5B model with k=32 matches 7B model performance 208 | - **Inference cost reduction**: Can use smaller models with more samples 209 | 210 | **Figure 4 (Off-policy sampling πs):** 211 | - Tests mixing historical responses with current policy samples 212 | - GVPO maintains **robust performance** with ratios from 0:8 to 4:4 (historical:current) 213 | - Validates Corollary 3.2's theoretical guarantee 214 | 215 | --- 216 | 217 | ## 6. Limitations and Future Work 218 | 219 | ### **Acknowledged Limitations:** 220 | 1. **Computational cost**: Still requires sampling k responses per prompt 221 | 2. **Reward model quality**: Performance depends on accurate R(x,y) 222 | 3. **Hyperparameter β**: Though robust, still requires selection 223 | 224 | ### **Unexplored Connections:** 225 | - Integration with exploration strategies from classical RL 226 | - Extension to continuous action spaces 227 | - Multi-modal reward signals 228 | 229 | --- 230 | 231 | ## 7. Summary: Why GVPO is Better 232 | 233 | | Criterion | GRPO | GVPO | 234 | |-----------|------|------| 235 | | **Training Stability** | ❌ Documented instability [34, 16] | ✅ Implicit regularization via Cov/Var terms | 236 | | **Hyperparameter Sensitivity** | ❌ High (clip threshold, KL coeff) | ✅ Robust to β variations | 237 | | **Theoretical Guarantee** | ❌ No convergence to optimal policy | ✅ Unique optimal = KL-constrained optimum | 238 | | **Sampling Flexibility** | ⚠️ Uses importance sampling | ✅ Any πs (no IS needed) | 239 | | **Normalization Bias** | ❌ Std normalization conflates difficulty | ✅ Only centering (no std division) | 240 | | **Gradient Explosion** | ❌ Requires clipping | ✅ No IS ratio → inherently stable | 241 | | **Performance** | Baseline | **Best across all benchmarks** | 242 | 243 | --- 244 | 245 | ## **Bottom Line** 246 | 247 | GVPO's core innovation is **operationalizing the closed-form optimal policy** through a mathematically elegant **zero-sum weight design** that: 248 | 1. Eliminates the intractable partition function 249 | 2. Embeds KL constraints directly into gradients 250 | 3. Enables off-policy training without importance sampling 251 | 4. Guarantees convergence to the unique optimal policy 252 | 253 | This makes GVPO a **theoretically principled AND empirically superior** alternative to GRPO for LLM post-training. 
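To close, a short sketch (ours; the group data below are invented) of the Bessel-corrected empirical loss from Theorem 3.4, written so that each prompt may contribute its own group size $k(x)$ as in Corollary 3.5:

```python
import torch

beta = 1.0

def gvpo_empirical_loss(groups):
    """MSE-form empirical GVPO loss over groups of possibly different size k(x).

    Each group is (rewards, logp_cur, logp_ref) for one prompt's k responses;
    the 1/(k-1) Bessel factor keeps the estimator unbiased even for small k.
    """
    total = 0.0
    for rewards, logp_cur, logp_ref in groups:
        k = rewards.numel()
        implicit = beta * (logp_cur - logp_ref)  # beta*log Z(x) cancels after centering
        err = (implicit - implicit.mean()) - (rewards - rewards.mean())
        total = total + (err ** 2).sum() / (k - 1)
    return total / len(groups)

# Two prompts with k = 2 and k = 3 responses respectively (mixed-source dataset)
groups = [
    (torch.tensor([1.0, 0.0]),
     torch.tensor([-40.0, -55.0]),
     torch.tensor([-41.0, -54.0])),
    (torch.tensor([0.5, 1.0, 0.0]),
     torch.tensor([-47.0, -40.0, -52.0]),
     torch.tensor([-48.0, -43.0, -50.0])),
]
print(gvpo_empirical_loss(groups))
```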
--------------------------------------------------------------------------------
/notebooks/multi_agent_collaboration.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "16df69d3-74d7-413b-82c7-f597f55eea78",
6 | "metadata": {},
7 | "source": [
8 | "# Multi-Agent Collaboration\n",
9 | "\n",
10 | "We will delve into the intricacies of agentic services and explore the strategies and best practices for cultivating effective multi-agent collaboration. We will examine how organizations can harness the collective intelligence of their agents, leverage emerging technologies to enhance coordination and communication, and navigate the challenges that often arise in complex, dynamic environments.\n",
11 | "\n",
12 | "\n",
13 | "\n"
\n" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "898c86e5-10f3-4dc3-93ac-2488b8e69259", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "#!pip install crewai --force-reinstall\n", 24 | "#!pip install --force-reinstall -v \"setuptools<70\"" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "e5b795ed-a9b7-445c-8487-fde81b71fd42", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os, sys\n", 35 | "from crewai import Agent, Task, Crew, Process\n", 36 | "from dotenv import load_dotenv\n", 37 | "from langchain.tools import DuckDuckGoSearchRun\n", 38 | "module_paths = [\"./\", \"../scripts\"]\n", 39 | "for module_path in module_paths:\n", 40 | " sys.path.append(os.path.abspath(module_path))\n", 41 | "from bedrock import *\n", 42 | "from crewai_tools import (\n", 43 | " DirectoryReadTool,\n", 44 | " FileReadTool,\n", 45 | " SerperDevTool,\n", 46 | " WebsiteSearchTool\n", 47 | ")\n", 48 | "\n", 49 | "load_dotenv()\n", 50 | "search_tool = DuckDuckGoSearchRun()\n", 51 | "web_rag_tool = WebsiteSearchTool()" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "fe7427ba-19da-4d4a-b579-04b747eedef3", 57 | "metadata": {}, 58 | "source": [ 59 | "## Define agents" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "7b479cc6-f898-4bdd-a98a-eeb9b46529dc", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "class blogAgents():\n", 70 | " def planner(self, topic, model_id):\n", 71 | " return Agent(\n", 72 | " role=\"Content Planner\",\n", 73 | " goal=f\"Plan engaging and factually accurate content on {topic}\",\n", 74 | " backstory=\"You're working on planning a blog article \"\n", 75 | " f\"about the topic: {topic}.\"\n", 76 | " \"You collect information by searhing the web for the latest developements that directly relate to the {topic}.\"\n", 77 | " \"audience learn something \"\n", 78 | " \"and make informed decisions. \"\n", 79 | " \"Your work is the basis for \"\n", 80 | " \"the Content Writer to write an article on this topic.\",\n", 81 | " allow_delegation=False,\n", 82 | " tools=[search_tool, web_rag_tool],\n", 83 | " llm=get_llm(model_id),\n", 84 | " verbose=True\n", 85 | " )\n", 86 | " \n", 87 | " def writer(self, topic, model_id):\n", 88 | " return Agent(\n", 89 | " role=\"Content Writer\",\n", 90 | " goal=f\"Write insightful and factually accurate opinion piece about the topic: {topic}\",\n", 91 | " backstory=\"You're working on a writing \"\n", 92 | " f\"a new opinion piece about the topic: {topic}. \"\n", 93 | " \"You base your writing on the work of \"\n", 94 | " \"the Content Planner, who provides an outline \"\n", 95 | " \"and relevant context about the topic. \"\n", 96 | " \"You follow the main objectives and \"\n", 97 | " \"direction of the outline, \"\n", 98 | " \"as provide by the Content Planner. \"\n", 99 | " \"You also provide objective and impartial insights \"\n", 100 | " \"and back them up with information \"\n", 101 | " \"provide by the Content Planner. 
\"\n", 102 | " \"You acknowledge in your opinion piece \"\n", 103 | " \"when your statements are opinions \"\n", 104 | " \"as opposed to objective statements.\",\n", 105 | " allow_delegation=False,\n", 106 | " llm=get_llm(model_id),\n", 107 | " verbose=True\n", 108 | " )\n", 109 | "\n", 110 | " def editor(self, model_id):\n", 111 | " return Agent(\n", 112 | " role=\"Editor\",\n", 113 | " goal=\"Edit a given blog post to align with \"\n", 114 | " \"the writing style of the organization. \",\n", 115 | " backstory=\"You are an editor who receives a blog post from the Content Writer. \"\n", 116 | " \"Your goal is to review the blog post to ensure that it follows journalistic best practices,\"\n", 117 | " \"provides balanced viewpoints when providing opinions or assertions, \"\n", 118 | " \"and also avoids major controversial topics or opinions when possible.\",\n", 119 | " allow_delegation=False,\n", 120 | " llm=get_llm(model_id),\n", 121 | " verbose=True\n", 122 | " )" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "id": "f412e0e3-5ef8-4ed2-98c5-f44b8996cba0", 128 | "metadata": {}, 129 | "source": [ 130 | "## Define the tasks (plan, write and edit) for the agents we created above." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "44bb42d5-9396-45aa-863e-34ef2a4921a4", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "class blogTasks():\n", 141 | " def plan(self, planner, topic): \n", 142 | " return Task(\n", 143 | " description=(\n", 144 | " \"1. Prioritize the latest trends, key players, \"\n", 145 | " f\"and noteworthy news on {topic}.\\n\"\n", 146 | " \"2. Identify the target audience, considering \"\n", 147 | " \"their interests and pain points.\\n\"\n", 148 | " \"3. Develop a detailed content outline including \"\n", 149 | " \"an introduction, key points, and a call to action.\\n\"\n", 150 | " \"4. Include SEO keywords and relevant data or sources.\"\n", 151 | " ),\n", 152 | " expected_output=\"A comprehensive content plan document \"\n", 153 | " \"with an outline, audience analysis, \"\n", 154 | " \"SEO keywords, and resources.\",\n", 155 | " agent=planner,\n", 156 | " )\n", 157 | " def write(self, writer, topic): \n", 158 | " return Task(\n", 159 | " description=(\n", 160 | " \"1. Use the content plan to craft a compelling \"\n", 161 | " f\"blog post on {topic}.\\n\"\n", 162 | " \"2. Incorporate SEO keywords naturally.\\n\"\n", 163 | " \"3. Sections/Subtitles are properly named \"\n", 164 | " \"in an engaging manner.\\n\"\n", 165 | " \"4. Ensure the post is structured with an \"\n", 166 | " \"engaging introduction, insightful body, \"\n", 167 | " \"and a summarizing conclusion.\\n\"\n", 168 | " \"5. 
Proofread for grammatical errors and \"\n", 169 | " \"alignment with the brand's voice.\\n\"\n", 170 | " ),\n", 171 | " expected_output=\"A well-written blog post \"\n", 172 | " \"in markdown format, ready for publication, \"\n", 173 | " \"each section should have 2 or 3 paragraphs.\",\n", 174 | " agent=writer,\n", 175 | " )\n", 176 | " \n", 177 | " def edit(self, editor):\n", 178 | " return Task(\n", 179 | " description=(\"Proofread the given blog post for \"\n", 180 | " \"grammatical errors and \"\n", 181 | " \"alignment with the brand's voice.\"),\n", 182 | " expected_output=\"A well-written blog post in markdown format, \"\n", 183 | " \"ready for publication, \"\n", 184 | " \"each section should have 2 or 3 paragraphs.\",\n", 185 | " agent=editor\n", 186 | " )" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "77c9e235-2cbc-4d80-bc31-9c43ada63b3a", 192 | "metadata": {}, 193 | "source": [ 194 | "## It’s time to assemble the crew. Combine the agents into our awesome crew." 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "50b03f9f-23d6-4ea6-8f7a-6f8ee56e94db", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "class blogCrew:\n", 205 | " def __init__(self, topic, model_id):\n", 206 | " self.topic = topic\n", 207 | " self.model_id = model_id\n", 208 | "\n", 209 | " def run(self):\n", 210 | " agents = blogAgents()\n", 211 | " tasks = blogTasks()\n", 212 | "\n", 213 | " planner_agent = agents.planner(self.topic, self.model_id)\n", 214 | " writer_agent = agents.writer(self.topic, self.model_id)\n", 215 | " editor_agent = agents.editor(self.model_id)\n", 216 | "\n", 217 | " plan_task = tasks.plan(planner_agent, self.topic)\n", 218 | " write_task = tasks.write(writer_agent, self.topic)\n", 219 | " edit_task = tasks.edit(editor_agent)\n", 220 | "\n", 221 | "\n", 222 | " crew = Crew(\n", 223 | " agents=[planner_agent, writer_agent, editor_agent],\n", 224 | " tasks=[plan_task, write_task, edit_task],\n", 225 | " verbose=True,\n", 226 | " memory=True,\n", 227 | " embedder={\n", 228 | " \"provider\": \"huggingface\",\n", 229 | " \"config\": {\"model\": \"sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\"},\n", 230 | " },\n", 231 | " cache=True,\n", 232 | " process=Process.sequential # Sequential process will have tasks executed one after the other and the outcome of the previous one is\n", 233 | " )\n", 234 | "\n", 235 | " result = crew.kickoff()\n", 236 | " return result" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "998a1a25-5b2d-400b-b8e0-00fd890979b3", 242 | "metadata": {}, 243 | "source": [ 244 | "## Write a blog post based on the input topic" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "4b83fe08-12ed-479c-aefc-2e2c959b33ef", 251 | "metadata": { 252 | "scrolled": true 253 | }, 254 | "outputs": [], 255 | "source": [ 256 | "topic = \"Write a release note for Amazon Q\"\n", 257 | "model_id = \"anthropic.claude-3-haiku-20240307-v1:0\"\n", 258 | "blog_crew = blogCrew(topic, model_id)\n", 259 | "result = blog_crew.run()" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "id": "791ad4fe-7df5-4f4f-a8ff-e05512ca13d6", 265 | "metadata": {}, 266 | "source": [ 267 | "## (Optional) Display the final output in markdown format " 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "id": "9734e88e-36f2-4975-8158-a4f28070c227", 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "from 
IPython.display import Markdown\n", 278 | "Markdown(result.raw)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "id": "2c5546b5-6d94-46b4-969f-08efccfeab2e", 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "medf", 293 | "language": "python", 294 | "name": "medf" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.11.5" 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 5 311 | } 312 | -------------------------------------------------------------------------------- /scripts/knowledge_base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import boto3 3 | import time 4 | from botocore.exceptions import ClientError 5 | from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth, RequestError 6 | import pprint 7 | from retrying import retry 8 | 9 | valid_embedding_models = ["cohere.embed-multilingual-v3", "cohere.embed-english-v3", "amazon.titan-embed-text-v1"] 10 | pp = pprint.PrettyPrinter(indent=2) 11 | 12 | 13 | def interactive_sleep(seconds: int): 14 | """ 15 | Support functionality to induce an artificial 'sleep' to the code in order to wait for resources to be available 16 | Args: 17 | seconds (int): number of seconds to sleep for 18 | """ 19 | dots = '' 20 | for i in range(seconds): 21 | dots += '.' 22 | print(dots, end='\r') 23 | time.sleep(1) 24 | 25 | 26 | class BedrockKnowledgeBase: 27 | """ 28 | Support class that allows for: 29 | - creation (or retrieval) of a Knowledge Base for Amazon Bedrock with all its pre-requisites 30 | (including OSS, IAM roles and Permissions and S3 bucket) 31 | - Ingestion of data into the Knowledge Base 32 | - Deletion of all resources created 33 | """ 34 | def __init__( 35 | self, 36 | kb_name, 37 | kb_description=None, 38 | data_bucket_name=None, 39 | embedding_model="amazon.titan-embed-text-v1" 40 | ): 41 | """ 42 | Class initializer 43 | Args: 44 | kb_name (str): the knowledge base name 45 | kb_description (str): knowledge base description 46 | data_bucket_name (str): name of s3 bucket to connect with knowledge base 47 | embedding_model (str): embedding model to use 48 | """ 49 | boto3_session = boto3.session.Session() 50 | self.region_name = boto3_session.region_name 51 | self.iam_client = boto3_session.client('iam') 52 | self.account_number = boto3.client('sts').get_caller_identity().get('Account') 53 | self.suffix = str(self.account_number)[:4] 54 | self.identity = boto3.client('sts').get_caller_identity()['Arn'] 55 | self.aoss_client = boto3_session.client('opensearchserverless') 56 | self.s3_client = boto3.client('s3') 57 | self.bedrock_agent_client = boto3.client('bedrock-agent') 58 | credentials = boto3.Session().get_credentials() 59 | self.awsauth = AWSV4SignerAuth(credentials, self.region_name, 'aoss') 60 | 61 | self.kb_name = kb_name 62 | self.kb_description = kb_description 63 | if data_bucket_name is not None: 64 | self.bucket_name = data_bucket_name 65 | else: 66 | self.bucket_name = f"{self.kb_name}-{self.suffix}" 67 | if embedding_model not in valid_embedding_models: 68 | valid_embeddings_str = str(valid_embedding_models) 69 | raise ValueError(f"Invalid embedding model. 
Your embedding model should be one of {valid_embeddings_str}") 70 | self.embedding_model = embedding_model 71 | self.encryption_policy_name = f"bedrock-sample-rag-sp-{self.suffix}" 72 | self.network_policy_name = f"bedrock-sample-rag-np-{self.suffix}" 73 | self.access_policy_name = f'bedrock-sample-rag-ap-{self.suffix}' 74 | self.kb_execution_role_name = f'AmazonBedrockExecutionRoleForKnowledgeBase_{self.suffix}' 75 | self.fm_policy_name = f'AmazonBedrockFoundationModelPolicyForKnowledgeBase_{self.suffix}' 76 | self.s3_policy_name = f'AmazonBedrockS3PolicyForKnowledgeBase_{self.suffix}' 77 | self.oss_policy_name = f'AmazonBedrockOSSPolicyForKnowledgeBase_{self.suffix}' 78 | 79 | self.vector_store_name = f'bedrock-sample-rag-{self.suffix}' 80 | self.index_name = f"bedrock-sample-rag-index-{self.suffix}" 81 | print("========================================================================================") 82 | print(f"Step 1 - Creating or retrieving {self.bucket_name} S3 bucket for Knowledge Base documents") 83 | self.create_s3_bucket() 84 | print("========================================================================================") 85 | print(f"Step 2 - Creating Knowledge Base Execution Role ({self.kb_execution_role_name}) and Policies") 86 | self.bedrock_kb_execution_role = self.create_bedrock_kb_execution_role() 87 | print("========================================================================================") 88 | print(f"Step 3 - Creating OSS encryption, network and data access policies") 89 | self.encryption_policy, self.network_policy, self.access_policy = self.create_policies_in_oss() 90 | print("========================================================================================") 91 | print(f"Step 4 - Creating OSS Collection (this step takes a couple of minutes to complete)") 92 | self.host, self.collection, self.collection_id, self.collection_arn = self.create_oss() 93 | # Build the OpenSearch client 94 | self.oss_client = OpenSearch( 95 | hosts=[{'host': self.host, 'port': 443}], 96 | http_auth=self.awsauth, 97 | use_ssl=True, 98 | verify_certs=True, 99 | connection_class=RequestsHttpConnection, 100 | timeout=300 101 | ) 102 | print("========================================================================================") 103 | print(f"Step 5 - Creating OSS Vector Index") 104 | self.create_vector_index() 105 | print("========================================================================================") 106 | print(f"Step 6 - Creating Knowledge Base") 107 | self.knowledge_base, self.data_source = self.create_knowledge_base() 108 | print("========================================================================================") 109 | 110 | def create_s3_bucket(self): 111 | """ 112 | Check if bucket exists, and if not create S3 bucket for knowledge base data source 113 | """ 114 | try: 115 | self.s3_client.head_bucket(Bucket=self.bucket_name) 116 | print(f'Bucket {self.bucket_name} already exists - retrieving it!') 117 | except ClientError as e: 118 | print(f'Creating bucket {self.bucket_name}') 119 | if self.region_name == "us-east-1": 120 | self.s3_client.create_bucket( 121 | Bucket=self.bucket_name 122 | ) 123 | else: 124 | self.s3_client.create_bucket( 125 | Bucket=self.bucket_name, 126 | CreateBucketConfiguration={'LocationConstraint': self.region_name} 127 | ) 128 | 129 | def create_bedrock_kb_execution_role(self): 130 | """ 131 | Create Knowledge Base Execution IAM Role and its required policies. 
132 | If role and/or policies already exist, retrieve them 133 | Returns: 134 | IAM role 135 | """ 136 | foundation_model_policy_document = { 137 | "Version": "2012-10-17", 138 | "Statement": [ 139 | { 140 | "Effect": "Allow", 141 | "Action": [ 142 | "bedrock:InvokeModel", 143 | ], 144 | "Resource": [ 145 | f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.embedding_model}" 146 | ] 147 | } 148 | ] 149 | } 150 | 151 | s3_policy_document = { 152 | "Version": "2012-10-17", 153 | "Statement": [ 154 | { 155 | "Effect": "Allow", 156 | "Action": [ 157 | "s3:GetObject", 158 | "s3:ListBucket" 159 | ], 160 | "Resource": [ 161 | f"arn:aws:s3:::{self.bucket_name}", 162 | f"arn:aws:s3:::{self.bucket_name}/*" 163 | ], 164 | "Condition": { 165 | "StringEquals": { 166 | "aws:ResourceAccount": f"{self.account_number}" 167 | } 168 | } 169 | } 170 | ] 171 | } 172 | 173 | assume_role_policy_document = { 174 | "Version": "2012-10-17", 175 | "Statement": [ 176 | { 177 | "Effect": "Allow", 178 | "Principal": { 179 | "Service": "bedrock.amazonaws.com" 180 | }, 181 | "Action": "sts:AssumeRole" 182 | } 183 | ] 184 | } 185 | try: 186 | # create policies based on the policy documents 187 | fm_policy = self.iam_client.create_policy( 188 | PolicyName=self.fm_policy_name, 189 | PolicyDocument=json.dumps(foundation_model_policy_document), 190 | Description='Policy for accessing foundation model', 191 | ) 192 | except self.iam_client.exceptions.EntityAlreadyExistsException: 193 | fm_policy = self.iam_client.get_policy( 194 | PolicyArn=f"arn:aws:iam::{self.account_number}:policy/{self.fm_policy_name}" 195 | ) 196 | 197 | try: 198 | s3_policy = self.iam_client.create_policy( 199 | PolicyName=self.s3_policy_name, 200 | PolicyDocument=json.dumps(s3_policy_document), 201 | Description='Policy for reading documents from s3') 202 | except self.iam_client.exceptions.EntityAlreadyExistsException: 203 | s3_policy = self.iam_client.get_policy( 204 | PolicyArn=f"arn:aws:iam::{self.account_number}:policy/{self.s3_policy_name}" 205 | ) 206 | # create bedrock execution role 207 | try: 208 | bedrock_kb_execution_role = self.iam_client.create_role( 209 | RoleName=self.kb_execution_role_name, 210 | AssumeRolePolicyDocument=json.dumps(assume_role_policy_document), 211 | Description='Amazon Bedrock Knowledge Base Execution Role for accessing OSS and S3', 212 | MaxSessionDuration=3600 213 | ) 214 | except self.iam_client.exceptions.EntityAlreadyExistsException: 215 | bedrock_kb_execution_role = self.iam_client.get_role( 216 | RoleName=self.kb_execution_role_name 217 | ) 218 | # fetch arn of the policies and role created above 219 | s3_policy_arn = s3_policy["Policy"]["Arn"] 220 | fm_policy_arn = fm_policy["Policy"]["Arn"] 221 | 222 | # attach policies to Amazon Bedrock execution role 223 | self.iam_client.attach_role_policy( 224 | RoleName=bedrock_kb_execution_role["Role"]["RoleName"], 225 | PolicyArn=fm_policy_arn 226 | ) 227 | self.iam_client.attach_role_policy( 228 | RoleName=bedrock_kb_execution_role["Role"]["RoleName"], 229 | PolicyArn=s3_policy_arn 230 | ) 231 | return bedrock_kb_execution_role 232 | 233 | def create_oss_policy_attach_bedrock_execution_role(self, collection_id): 234 | """ 235 | Create OpenSearch Serverless policy and attach it to the Knowledge Base Execution role. 
236 | If policy already exists, attaches it 237 | """ 238 | # define oss policy document 239 | oss_policy_document = { 240 | "Version": "2012-10-17", 241 | "Statement": [ 242 | { 243 | "Effect": "Allow", 244 | "Action": [ 245 | "aoss:APIAccessAll" 246 | ], 247 | "Resource": [ 248 | f"arn:aws:aoss:{self.region_name}:{self.account_number}:collection/{collection_id}" 249 | ] 250 | } 251 | ] 252 | } 253 | 254 | oss_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{self.oss_policy_name}" 255 | created = False 256 | try: 257 | self.iam_client.create_policy( 258 | PolicyName=self.oss_policy_name, 259 | PolicyDocument=json.dumps(oss_policy_document), 260 | Description='Policy for accessing opensearch serverless', 261 | ) 262 | created = True 263 | except self.iam_client.exceptions.EntityAlreadyExistsException: 264 | print(f"Policy {oss_policy_arn} already exists, skipping creation") 265 | print("Opensearch serverless arn: ", oss_policy_arn) 266 | 267 | self.iam_client.attach_role_policy( 268 | RoleName=self.bedrock_kb_execution_role["Role"]["RoleName"], 269 | PolicyArn=oss_policy_arn 270 | ) 271 | return created 272 | 273 | def create_policies_in_oss(self): 274 | """ 275 | Create OpenSearch Serverless encryption, network and data access policies. 276 | If policies already exist, retrieve them 277 | """ 278 | try: 279 | encryption_policy = self.aoss_client.create_security_policy( 280 | name=self.encryption_policy_name, 281 | policy=json.dumps( 282 | { 283 | 'Rules': [{'Resource': ['collection/' + self.vector_store_name], 284 | 'ResourceType': 'collection'}], 285 | 'AWSOwnedKey': True 286 | }), 287 | type='encryption' 288 | ) 289 | except self.aoss_client.exceptions.ConflictException: 290 | encryption_policy = self.aoss_client.get_security_policy( 291 | name=self.encryption_policy_name, 292 | type='encryption' 293 | ) 294 | 295 | try: 296 | network_policy = self.aoss_client.create_security_policy( 297 | name=self.network_policy_name, 298 | policy=json.dumps( 299 | [ 300 | {'Rules': [{'Resource': ['collection/' + self.vector_store_name], 301 | 'ResourceType': 'collection'}], 302 | 'AllowFromPublic': True} 303 | ]), 304 | type='network' 305 | ) 306 | except self.aoss_client.exceptions.ConflictException: 307 | network_policy = self.aoss_client.get_security_policy( 308 | name=self.network_policy_name, 309 | type='network' 310 | ) 311 | 312 | try: 313 | access_policy = self.aoss_client.create_access_policy( 314 | name=self.access_policy_name, 315 | policy=json.dumps( 316 | [ 317 | { 318 | 'Rules': [ 319 | { 320 | 'Resource': ['collection/' + self.vector_store_name], 321 | 'Permission': [ 322 | 'aoss:CreateCollectionItems', 323 | 'aoss:DeleteCollectionItems', 324 | 'aoss:UpdateCollectionItems', 325 | 'aoss:DescribeCollectionItems'], 326 | 'ResourceType': 'collection' 327 | }, 328 | { 329 | 'Resource': ['index/' + self.vector_store_name + '/*'], 330 | 'Permission': [ 331 | 'aoss:CreateIndex', 332 | 'aoss:DeleteIndex', 333 | 'aoss:UpdateIndex', 334 | 'aoss:DescribeIndex', 335 | 'aoss:ReadDocument', 336 | 'aoss:WriteDocument'], 337 | 'ResourceType': 'index' 338 | }], 339 | 'Principal': [self.identity, self.bedrock_kb_execution_role['Role']['Arn']], 340 | 'Description': 'Easy data policy'} 341 | ]), 342 | type='data' 343 | ) 344 | except self.aoss_client.exceptions.ConflictException: 345 | access_policy = self.aoss_client.get_access_policy( 346 | name=self.access_policy_name, 347 | type='data' 348 | ) 349 | 350 | return encryption_policy, network_policy, access_policy 351 | 352 | def create_oss(self): 
353 | """ 354 | Create OpenSearch Serverless Collection. If already existent, retrieve 355 | """ 356 | try: 357 | collection = self.aoss_client.create_collection(name=self.vector_store_name, type='VECTORSEARCH') 358 | collection_id = collection['createCollectionDetail']['id'] 359 | collection_arn = collection['createCollectionDetail']['arn'] 360 | except self.aoss_client.exceptions.ConflictException: 361 | collection = self.aoss_client.batch_get_collection(names=[self.vector_store_name])['collectionDetails'][0] 362 | pp.pprint(collection) 363 | collection_id = collection['id'] 364 | collection_arn = collection['arn'] 365 | pp.pprint(collection) 366 | 367 | # Get the OpenSearch serverless collection URL 368 | host = collection_id + '.' + self.region_name + '.aoss.amazonaws.com' 369 | print(host) 370 | # wait for collection creation 371 | # This can take couple of minutes to finish 372 | response = self.aoss_client.batch_get_collection(names=[self.vector_store_name]) 373 | # Periodically check collection status 374 | while (response['collectionDetails'][0]['status']) == 'CREATING': 375 | print('Creating collection...') 376 | interactive_sleep(30) 377 | response = self.aoss_client.batch_get_collection(names=[self.vector_store_name]) 378 | print('\nCollection successfully created:') 379 | pp.pprint(response["collectionDetails"]) 380 | # create opensearch serverless access policy and attach it to Bedrock execution role 381 | try: 382 | created = self.create_oss_policy_attach_bedrock_execution_role(collection_id) 383 | if created: 384 | # It can take up to a minute for data access rules to be enforced 385 | print("Sleeping for a minute to ensure data access rules have been enforced") 386 | interactive_sleep(60) 387 | return host, collection, collection_id, collection_arn 388 | except Exception as e: 389 | print("Policy already exists") 390 | pp.pprint(e) 391 | 392 | def create_vector_index(self): 393 | """ 394 | Create OpenSearch Serverless vector index. If existent, ignore 395 | """ 396 | body_json = { 397 | "settings": { 398 | "index.knn": "true", 399 | "number_of_shards": 1, 400 | "knn.algo_param.ef_search": 512, 401 | "number_of_replicas": 0, 402 | }, 403 | "mappings": { 404 | "properties": { 405 | "vector": { 406 | "type": "knn_vector", 407 | "dimension": 1536, 408 | "method": { 409 | "name": "hnsw", 410 | "engine": "faiss", 411 | "space_type": "l2" 412 | }, 413 | }, 414 | "text": { 415 | "type": "text" 416 | }, 417 | "text-metadata": { 418 | "type": "text"} 419 | } 420 | } 421 | } 422 | 423 | # Create index 424 | try: 425 | response = self.oss_client.indices.create(index=self.index_name, body=json.dumps(body_json)) 426 | print('\nCreating index:') 427 | pp.pprint(response) 428 | 429 | # index creation can take up to a minute 430 | interactive_sleep(60) 431 | except RequestError as e: 432 | # you can delete the index if its already exists 433 | # oss_client.indices.delete(index=index_name) 434 | print( 435 | f'Error while trying to create the index, with error {e.error}\nyou may unmark the delete above to ' 436 | f'delete, and recreate the index') 437 | 438 | @retry(wait_random_min=1000, wait_random_max=2000, stop_max_attempt_number=7) 439 | def create_knowledge_base(self): 440 | """ 441 | Create Knowledge Base and its Data Source. 
If existent, retrieve 442 | """ 443 | opensearch_serverless_configuration = { 444 | "collectionArn": self.collection_arn, 445 | "vectorIndexName": self.index_name, 446 | "fieldMapping": { 447 | "vectorField": "vector", 448 | "textField": "text", 449 | "metadataField": "text-metadata" 450 | } 451 | } 452 | 453 | # Ingest strategy - How to ingest data from the data source 454 | chunking_strategy_configuration = { 455 | "chunkingStrategy": "FIXED_SIZE", 456 | "fixedSizeChunkingConfiguration": { 457 | "maxTokens": 512, 458 | "overlapPercentage": 20 459 | } 460 | } 461 | 462 | # The data source to ingest documents from, into the OpenSearch serverless knowledge base index 463 | s3_configuration = { 464 | "bucketArn": f"arn:aws:s3:::{self.bucket_name}", 465 | # "inclusionPrefixes":["*.*"] # you can use this if you want to create a KB using data within s3 prefixes. 466 | } 467 | 468 | # The embedding model used by Bedrock to embed ingested documents, and realtime prompts 469 | embedding_model_arn = f"arn:aws:bedrock:{self.region_name}::foundation-model/{self.embedding_model}" 470 | try: 471 | create_kb_response = self.bedrock_agent_client.create_knowledge_base( 472 | name=self.kb_name, 473 | description=self.kb_description, 474 | roleArn=self.bedrock_kb_execution_role['Role']['Arn'], 475 | knowledgeBaseConfiguration={ 476 | "type": "VECTOR", 477 | "vectorKnowledgeBaseConfiguration": { 478 | "embeddingModelArn": embedding_model_arn 479 | } 480 | }, 481 | storageConfiguration={ 482 | "type": "OPENSEARCH_SERVERLESS", 483 | "opensearchServerlessConfiguration": opensearch_serverless_configuration 484 | } 485 | ) 486 | kb = create_kb_response["knowledgeBase"] 487 | pp.pprint(kb) 488 | except self.bedrock_agent_client.exceptions.ConflictException: 489 | kbs = self.bedrock_agent_client.list_knowledge_bases( 490 | maxResults=100 491 | ) 492 | kb_id = None 493 | for kb in kbs['knowledgeBaseSummaries']: 494 | if kb['name'] == self.kb_name: 495 | kb_id = kb['knowledgeBaseId'] 496 | response = self.bedrock_agent_client.get_knowledge_base(knowledgeBaseId=kb_id) 497 | kb = response['knowledgeBase'] 498 | pp.pprint(kb) 499 | 500 | # Create a DataSource in KnowledgeBase 501 | try: 502 | create_ds_response = self.bedrock_agent_client.create_data_source( 503 | name=self.kb_name, 504 | description=self.kb_description, 505 | knowledgeBaseId=kb['knowledgeBaseId'], 506 | dataSourceConfiguration={ 507 | "type": "S3", 508 | "s3Configuration": s3_configuration 509 | }, 510 | vectorIngestionConfiguration={ 511 | "chunkingConfiguration": chunking_strategy_configuration 512 | } 513 | ) 514 | ds = create_ds_response["dataSource"] 515 | pp.pprint(ds) 516 | except self.bedrock_agent_client.exceptions.ConflictException: 517 | ds_id = self.bedrock_agent_client.list_data_sources( 518 | knowledgeBaseId=kb['knowledgeBaseId'], 519 | maxResults=100 520 | )['dataSourceSummaries'][0]['dataSourceId'] 521 | get_ds_response = self.bedrock_agent_client.get_data_source( 522 | dataSourceId=ds_id, 523 | knowledgeBaseId=kb['knowledgeBaseId'] 524 | ) 525 | ds = get_ds_response["dataSource"] 526 | pp.pprint(ds) 527 | return kb, ds 528 | 529 | def start_ingestion_job(self): 530 | """ 531 | Start an ingestion job to synchronize data from an S3 bucket to the Knowledge Base 532 | """ 533 | # Start an ingestion job 534 | start_job_response = self.bedrock_agent_client.start_ingestion_job( 535 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'], 536 | dataSourceId=self.data_source["dataSourceId"] 537 | ) 538 | job = 
start_job_response["ingestionJob"] 539 | pp.pprint(job) 540 | # Get job 541 | while job['status'] != 'COMPLETE': 542 | get_job_response = self.bedrock_agent_client.get_ingestion_job( 543 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'], 544 | dataSourceId=self.data_source["dataSourceId"], 545 | ingestionJobId=job["ingestionJobId"] 546 | ) 547 | job = get_job_response["ingestionJob"] 548 | pp.pprint(job) 549 | interactive_sleep(40) 550 | 551 | def get_knowledge_base_id(self): 552 | """ 553 | Get Knowledge Base Id 554 | """ 555 | pp.pprint(self.knowledge_base["knowledgeBaseId"]) 556 | return self.knowledge_base["knowledgeBaseId"] 557 | 558 | def get_bucket_name(self): 559 | """ 560 | Get the name of the bucket connected with the Knowledge Base Data Source 561 | """ 562 | pp.pprint(f"Bucket connected with KB: {self.bucket_name}") 563 | return self.bucket_name 564 | 565 | def delete_kb(self, delete_s3_bucket=False, delete_iam_roles_and_policies=True): 566 | """ 567 | Delete the Knowledge Base resources 568 | Args: 569 | delete_s3_bucket (bool): boolean to indicate if s3 bucket should also be deleted 570 | delete_iam_roles_and_policies (bool): boolean to indicate if IAM roles and Policies should also be deleted 571 | """ 572 | self.bedrock_agent_client.delete_data_source( 573 | dataSourceId=self.data_source["dataSourceId"], 574 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'] 575 | ) 576 | self.bedrock_agent_client.delete_knowledge_base( 577 | knowledgeBaseId=self.knowledge_base['knowledgeBaseId'] 578 | ) 579 | self.oss_client.indices.delete(index=self.index_name) 580 | self.aoss_client.delete_collection(id=self.collection_id) 581 | self.aoss_client.delete_access_policy( 582 | type="data", 583 | name=self.access_policy_name 584 | ) 585 | self.aoss_client.delete_security_policy( 586 | type="network", 587 | name=self.network_policy_name 588 | ) 589 | self.aoss_client.delete_security_policy( 590 | type="encryption", 591 | name=self.encryption_policy_name 592 | ) 593 | if delete_s3_bucket: 594 | self.delete_s3() 595 | if delete_iam_roles_and_policies: 596 | self.delete_iam_roles_and_policies() 597 | 598 | def delete_iam_roles_and_policies(self): 599 | """ 600 | Delete IAM Roles and policies used by the Knowledge Base 601 | """ 602 | fm_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{self.fm_policy_name}" 603 | s3_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{self.s3_policy_name}" 604 | oss_policy_arn = f"arn:aws:iam::{self.account_number}:policy/{self.oss_policy_name}" 605 | self.iam_client.detach_role_policy( 606 | RoleName=self.kb_execution_role_name, 607 | PolicyArn=s3_policy_arn 608 | ) 609 | self.iam_client.detach_role_policy( 610 | RoleName=self.kb_execution_role_name, 611 | PolicyArn=fm_policy_arn 612 | ) 613 | self.iam_client.detach_role_policy( 614 | RoleName=self.kb_execution_role_name, 615 | PolicyArn=oss_policy_arn 616 | ) 617 | self.iam_client.delete_role(RoleName=self.kb_execution_role_name) 618 | self.iam_client.delete_policy(PolicyArn=s3_policy_arn) 619 | self.iam_client.delete_policy(PolicyArn=fm_policy_arn) 620 | self.iam_client.delete_policy(PolicyArn=oss_policy_arn) 621 | return 0 622 | 623 | def delete_s3(self): 624 | """ 625 | Delete the objects contained in the Knowledge Base S3 bucket. 
626 |         Once the bucket is empty, delete the bucket
627 |         """
628 |         objects = self.s3_client.list_objects(Bucket=self.bucket_name)
629 |         if 'Contents' in objects:
630 |             for obj in objects['Contents']:
631 |                 self.s3_client.delete_object(Bucket=self.bucket_name, Key=obj['Key'])
632 |         self.s3_client.delete_bucket(Bucket=self.bucket_name)
--------------------------------------------------------------------------------
/notebooks/reasoning_with_langgraph_bedrock_workshop.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "id": "52b30b6b-e2c4-4c9c-905f-530125707d29",
6 |    "metadata": {},
7 |    "source": [
8 |     "# Build Multi-Agentic Services with Orchestration for Reasoning\n",
9 |     "\n",
10 |     "This workshop will focus on creating a sophisticated system of coordinated AI agents. We'll incorporate recent breakthroughs in generative AI to enhance the system's reasoning capabilities. By leveraging the power of LangGraph and the advanced language models available through Amazon Bedrock, including Claude-3.5 and LLaMA-3.1, we will create a versatile and capable system that can tackle a wide range of challenges.\n",
11 |     "\n",
12 |     "\n",
13 |     "\n",
14 |     "Key Features:\n",
15 |     "\n",
16 |     "* Agentic Service: The agent will be designed as a service, allowing for seamless integration and deployment in various applications.\n",
17 |     "* Dynamic Prompt Rewriting: The agent will dynamically rewrite prompts to optimize the responses from the underlying language models, ensuring more accurate and informative outputs.\n",
18 |     "* Adaptive Routing: Inspired by the Semantic Router, the routing agent will intelligently route requests to retrieval, web search, or pre-trained LLMs for the most desirable answers, leveraging the strengths of each method for optimal performance. This adaptive routing mechanism will ensure that the agent can effectively handle a diverse set of queries and tasks.\n",
19 |     "* Hallucination Grader: The agent will include a hallucination grader component to assess the reliability of the generated responses. This will help identify and correct any hallucinations or incomplete answers.\n",
20 |     "* Human Involvement: If needed, the agent will be able to involve human subject matter experts to provide additional verification and correction of the responses, further improving the trustworthiness and reliability of the system.\n",
21 |     "\n",
22 |     "By combining these advanced reasoning techniques, the agentic services with orchestration will be able to provide more accurate and informative responses to challenging queries. This will be a significant step forward in the development of intelligent agents that can truly understand and respond to complex questions.\n",
23 |     "\n",
24 |     "To build this powerful system, we will use LangGraph to create complex, multi-step workflows that involve language models and other components. This will allow us to develop a flexible and scalable system that can handle a wide range of tasks.\n",
25 |     "\n",
26 |     "To further enhance our capabilities, we will port the original notebook to utilize Amazon Bedrock for LLM inference. This will enable us to leverage the cloud processing power and take advantage of the advanced language models available through Amazon Bedrock, such as Claude-3 and LLaMA-3. 
By harnessing the power of these cutting-edge language models, we will be able to push the boundaries of what is possible with intelligent agents.\n", 27 | "\n", 28 | "While the choice of vector stores (local chromaDB) will remain unchanged for now, we will explore how to scale this part in future blog posts, ensuring that our system can handle ever-growing amounts of data and information.\n", 29 | "\n", 30 | "Join us in this exciting workshop as we embark on a journey to create an intelligent agent that redefines the boundaries of what is possible with language-based AI systems. Together, we will explore the latest advancements in the field and push the limits of what can be achieved.\n", 31 | "\n" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "f1665046-6a6e-43f8-a6d7-46e41c13ed18", 38 | "metadata": { 39 | "scrolled": true 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "!pip install -q -r ../requirements2.txt -U " 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "2eaf8ec0-2a7c-4283-a03f-e8ead72e49b1", 49 | "metadata": {}, 50 | "source": [ 51 | "**Restart the kernel after pkg install**" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "3753831a-fefc-4097-9bda-3e0e1393f7a4", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import os\n", 62 | "os._exit(00)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "ee6fa942-0a95-40a1-87f7-c3d98806a015", 68 | "metadata": {}, 69 | "source": [ 70 | "## 1. Setting Up API keys or tokens \n", 71 | "\n", 72 | "To access various services, such as Amazon Bedrock for Large Language Models (LLMs) and embedding models, Tavily web search engine, and optional Langchain, you will need to set up and obtain the necessary API keys or tokens. These API keys and tokens serve as authentication credentials that allow your application to securely connect and interact with the respective services." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "c41bbfd4-9225-4747-9e3a-262af2b8dca7", 79 | "metadata": { 80 | "tags": [] 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "import sys\n", 85 | "import os\n", 86 | "import boto3\n", 87 | "import json\n", 88 | "import requests\n", 89 | "\n", 90 | "\n", 91 | "aws_region = \"us-west-2\" # choose your region you operate in\n", 92 | "os.environ['TAVILY_API_KEY'] = tavily_ai_api_key = 'tvly-NA' # For extra search result. Optional\n", 93 | "os.environ['OPENAI_API_KEY'] = openai_api_key = 'sk-NA' # Only when you elect to use OpenAI's ada as embedding model. Otherwise you just need to assign an empty key. \n", 94 | "# Temp image file\n", 95 | "temp_gen_image = \"./delme.png\"\n", 96 | "markdown_filename = \"./blogpost.md\"\n", 97 | "\n", 98 | "module_paths = [\"./\", \"../scripts\"]\n", 99 | "for module_path in module_paths:\n", 100 | " sys.path.append(os.path.abspath(module_path))\n", 101 | "\n", 102 | "from blog_writer import *\n", 103 | "from bedrock import *\n", 104 | "\n", 105 | "#os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", 106 | "#os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.smith.langchain.com\"\n", 107 | "#os.environ[\"LANGCHAIN_API_KEY\"] = langchain_api_key" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "c1a16bda-dd03-4e7a-abb4-1da18f0f55af", 113 | "metadata": {}, 114 | "source": [ 115 | "## 2. Creating a Bedrock Runtime Client\n", 116 | "We'll create a Bedrock runtime client to connect to the Amazon Bedrock service. 
Bedrock, a fully managed service by AWS, allows developers to build and deploy generative AI models like large language models (LLMs). This client will enable us to leverage pre-trained LLMs available through Bedrock, such as Meta's powerful LLaMA3 model.\n",
117 |     "\n",
118 |     "Connecting to Bedrock is crucial for building our scalable and secure RAG agent, as it provides the necessary language model for generation capabilities. With the Bedrock runtime client in place, we can integrate LLaMA3 into our workflow and use its advanced natural language processing capabilities to generate accurate responses."
119 |    ]
120 |   },
121 |   {
122 |    "cell_type": "code",
123 |    "execution_count": null,
124 |    "id": "8293dd71-f502-46f6-b039-935ff458accb",
125 |    "metadata": {
126 |     "tags": []
127 |    },
128 |    "outputs": [],
129 |    "source": [
130 |     "### Select models\n",
131 |     "import ipywidgets as widgets\n",
132 |     "from ipywidgets import interact, interactive, fixed\n",
133 |     "\n",
134 |     "options = [\"mistral.mistral-large-2407-v1:0\", \"anthropic.claude-3-haiku-20240307-v1:0\", \"anthropic.claude-3-5-sonnet-20240620-v1:0\", \"meta.llama3-1-70b-instruct-v1:0\"]\n",
135 |     "# Create the dropdown widget\n",
136 |     "dropdown = widgets.Dropdown(\n",
137 |     "    options=options,\n",
138 |     "    value=options[1],\n",
139 |     "    description='Choose an option:',\n",
140 |     "    disabled=False,\n",
141 |     ")\n",
142 |     "\n",
143 |     "# Display the dropdown widget\n",
144 |     "display(dropdown)"
145 |    ]
146 |   },
147 |   {
148 |    "cell_type": "code",
149 |    "execution_count": null,
150 |    "id": "ed332f82-89e8-4a99-aa84-1a4533042ee4",
151 |    "metadata": {
152 |     "tags": []
153 |    },
154 |    "outputs": [],
155 |    "source": [
156 |     "model_id = dropdown.value\n",
157 |     "llm = get_llm(model_id)\n",
158 |     "model_id_l31 = 'meta.llama3-1-70b-instruct-v1:0'\n",
159 |     "model_id_c35 = \"anthropic.claude-3-sonnet-20240229-v1:0\" # Due to model access restriction #'anthropic.claude-3-5-sonnet-20240620-v1:0' \n",
160 |     "model_id_mistral_large = 'mistral.mistral-large-2407-v1:0'\n",
161 |     "# Choose multiple models for different purposes, to diversify and avoid potential bias \n",
162 |     "llm_llama31 = get_llm(model_id_l31)\n",
163 |     "llm_claude35 = get_llm(model_id_c35)\n",
164 |     "llm_mistral = get_llm(model_id_mistral_large)"
165 |    ]
166 |   },
167 |   {
168 |    "cell_type": "markdown",
169 |    "id": "b058bb22-85c0-4f4f-9795-fd686ff9466a",
170 |    "metadata": {},
171 |    "source": [
172 |     "## 3. Create agentic services with multi-agent capability\n",
173 |     "\n",
174 |     "Creating agentic services with multi-agent capability using Amazon Bedrock, the Converse API, and LangChain can be a powerful approach to building intelligent and collaborative systems. Amazon Bedrock provides a foundation for developing large language models (LLMs) and integrating them into applications, while the Converse API enables seamless communication between these models and external services. LangChain, on the other hand, offers a framework for building complex, multi-agent systems that can leverage the capabilities of various LLMs and other AI components. By combining these tools, developers can create agentic services that can engage in dynamic, context-aware interactions, share knowledge, and coordinate their efforts to tackle complex tasks. This approach can be particularly useful in scenarios where a diverse set of specialized agents need to collaborate, such as in enterprise automation, customer service, or research and development.\n",
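    "\n",
    "As a quick orientation, a single-turn request through the Converse API looks like the minimal sketch below (the model ID is just one of the options selected earlier in this notebook):\n",
    "\n",
    "```python\n",
    "import boto3\n",
    "\n",
    "client = boto3.client(\"bedrock-runtime\")\n",
    "response = client.converse(\n",
    "    modelId=\"anthropic.claude-3-haiku-20240307-v1:0\",\n",
    "    messages=[{\"role\": \"user\", \"content\": [{\"text\": \"Hello!\"}]}],\n",
    "    inferenceConfig={\"maxTokens\": 256, \"temperature\": 0.1},\n",
    ")\n",
    "print(response[\"output\"][\"message\"][\"content\"][0][\"text\"])\n",
    "```"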
176 |    ]
177 |   },
178 |   {
179 |    "cell_type": "code",
180 |    "execution_count": null,
181 |    "id": "1a35bb0a-ba58-4a31-a518-a3a067899c5f",
182 |    "metadata": {
183 |     "tags": []
184 |    },
185 |    "outputs": [],
186 |    "source": [
187 |     "from IPython.display import Image, Markdown\n",
188 |     "from langgraph.prebuilt import create_react_agent\n",
189 |     "from chromadb import Documents, EmbeddingFunction, Embeddings\n",
190 |     "from langchain_aws import BedrockEmbeddings\n",
191 |     "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
192 |     "from langchain_community.document_loaders import WebBaseLoader\n",
193 |     "from langchain_community.vectorstores import Chroma\n",
194 |     "from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, AnyMessage\n",
195 |     "from typing import TypedDict\n",
196 |     "from langgraph.graph import StateGraph, END\n",
197 |     "from langgraph.checkpoint.memory import MemorySaver\n",
198 |     "from botocore.exceptions import ClientError\n",
    "from botocore.config import Config  # explicit import for the retry configuration below\n",
199 |     "from langchain.prompts import PromptTemplate, ChatPromptTemplate\n",
200 |     "from langchain_core.output_parsers import JsonOutputParser, StrOutputParser\n",
201 |     "from langchain_community.tools.tavily_search import TavilySearchResults\n",
202 |     "from langchain_community.tools import DuckDuckGoSearchResults"
203 |    ]
204 |   },
205 |   {
206 |    "cell_type": "code",
207 |    "execution_count": null,
208 |    "id": "446454c8-42b3-4b91-8a49-d6dc3c1ef3a6",
209 |    "metadata": {
210 |     "tags": []
211 |    },
212 |    "outputs": [],
213 |    "source": [
214 |     "config = Config(\n",
215 |     "    retries = dict(\n",
216 |     "        max_attempts = 10,\n",
217 |     "        total_max_attempts = 25,\n",
218 |     "    )\n",
219 |     ")\n",
220 |     "bedrock_client = boto3.client(\"bedrock-runtime\", config=config) "
221 |    ]
222 |   },
223 |   {
224 |    "cell_type": "code",
225 |    "execution_count": null,
226 |    "id": "e3617811-2627-4c3f-aef4-2be4797b80e7",
227 |    "metadata": {
228 |     "tags": []
229 |    },
230 |    "outputs": [],
231 |    "source": [
232 |     "class MyEmbeddingFunction(EmbeddingFunction):\n",
233 |     "    def __init__(self, client, region_name: str, model_id: str):\n",
234 |     "        self.embedder = BedrockEmbeddings(\n",
235 |     "            client=client,\n",
236 |     "            region_name=region_name,\n",
237 |     "            model_id=model_id\n",
238 |     "        )\n",
239 |     "    def embed_query(self, query: str) -> Embeddings:\n",
240 |     "        return self.embedder.embed_query(query)\n",
241 |     "    def embed_documents(self, documents: list[str]) -> Embeddings:\n",
242 |     "        return self.embedder.embed_documents(documents)\n",
243 |     "\n",
244 |     "class MultiAgentState(TypedDict):\n",
245 |     "    question: str\n",
246 |     "    question_type: str\n",
247 |     "    answer: str\n",
248 |     "    feedback: str\n",
249 |     "\n",
250 |     "\n",
251 |     "memory = MemorySaver()\n",
252 |     "embedding_model_id = \"amazon.titan-embed-text-v2:0\"\n",
253 |     "\n",
254 |     "####\n",
255 |     "# Router\n",
256 |     "###\n",
257 |     "def route_question(state: MultiAgentState):\n",
258 |     "    print('route function execution')\n",
259 |     "    print(state)\n",
260 |     "    return state['question_type']\n",
261 |     "\n",
262 |     "\n",
263 |     "####\n",
264 |     "# Rewrite the question\n",
265 |     "####\n",
266 |     "def rewrite_node(state: MultiAgentState):\n",
267 |     "    \"\"\"\n",
268 |     "    Rewrite the user query to match a domain-expert prompt\n",
269 |     "    Args:\n",
270 |     "        question (str): The user query\n",
271 |     "    Returns:\n",
272 |     "        prompt (str): the rewritten question, phrased as an expert prompt\n",
273 |     "    \"\"\"\n",
274 |     "    print(\"---REWRITE QUESTION---\")\n",
275 |     "    c3_template = \"\"\"Rewrite the question by following the {{instruction}} to 
capture more precise and comprehensive intent from {question}.\n",
276 |     "    <instruction>\n",
277 |     "    Identify the key purposes, concepts and entities in the original {{question}}. \n",
278 |     "    Rephrase the question to be more specific and focused, ensuring that the language is clear and unambiguous. \n",
279 |     "    Provide additional context or background information that may be helpful for a web search or RAG system to better understand and respond to the question. \n",
280 |     "    Output your rewritten question only, without answering it or repeating the original one.\n",
281 |     "    </instruction>\n",
282 |     "    \"\"\"\n",
283 |     "    \n",
284 |     "    c3_prompt = ChatPromptTemplate.from_template(c3_template)\n",
285 |     "    #chain = ( c3_prompt | llm_c3 | StrOutputParser() | (lambda x: x.split(\"\\n\")))\n",
286 |     "    rewritten_chain = ( c3_prompt | llm | StrOutputParser() )\n",
287 |     "    rewritten_question = rewritten_chain.invoke({\"question\": state['question']})\n",
288 |     "    print(rewritten_question)\n",
289 |     "    if os.path.exists(temp_gen_image):\n",
290 |     "        os.remove(temp_gen_image)\n",
291 |     "    return {\"answer\": rewritten_question}\n",
292 |     "\n",
293 |     "\n",
294 |     "#####\n",
295 |     "# Router agent\n",
296 |     "#####\n",
297 |     "question_category_prompt = '''You are a senior specialist of analytical support. Your task is to classify the incoming questions. \n",
298 |     "Depending on your answer, the question will be routed to the right team, so your task is crucial. \n",
299 |     "There are 6 possible question types: \n",
300 |     "- Vectorstore - Answer questions related to pre-indexed healthcare and medical research topics stored in the vectorstore.\n",
301 |     "- Websearch - Answer questions about events that happened recently, after most LLMs' cut-off dates. \n",
302 |     "- General - Answer general questions with one LLM or a few LLMs.\n",
303 |     "- Text2image - Generate an image from text input.\n",
304 |     "- Booking - Assist in restaurant reservation booking.\n",
305 |     "- BlogWriter - Write a blog post about the provided topic as a professional writer.\n",
306 |     "Return in the output only one word (VECTORSTORE, WEBSEARCH, GENERAL, TEXT2IMAGE, BOOKING or BLOGWRITER).\n",
307 |     "'''\n",
308 |     "\n",
309 |     "def router_node(state: MultiAgentState):\n",
310 |     "    print('Router node started execution')\n",
311 |     "    messages = [\n",
312 |     "        SystemMessage(content=question_category_prompt), \n",
313 |     "        HumanMessage(content=state['question'])\n",
314 |     "    ]\n",
315 |     "    response = llm.invoke(messages)\n",
316 |     "    print('Question type: %s' % response.content)\n",
317 |     "    return {\"question_type\": response.content}\n",
318 |     "\n",
319 |     "\n",
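    "# NOTE: route_question (above) returns router_node's raw classification, and the\n",
    "# conditional edges in section 4 match it against the literal keys VECTORSTORE,\n",
    "# WEBSEARCH, GENERAL, TEXT2IMAGE, BOOKING and BLOGWRITER, so the router's output\n",
    "# must be exactly one of those words.\n",
    "\n",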
\n", 331 | " '''\n", 332 | " search_agent = create_react_agent(llm, [duck_search, tavily_tool],\n", 333 | " state_modifier = search_expert_system_prompt)\n", 334 | " messages = [HumanMessage(content=state['question'])]\n", 335 | " result = search_agent.invoke({\"messages\": messages})\n", 336 | " return {'answer': result['messages'][-1].content}\n", 337 | "\n", 338 | "\n", 339 | "#######\n", 340 | "# RAG\n", 341 | "#######\n", 342 | "def rag_node(state: MultiAgentState):\n", 343 | " urls = [\n", 344 | " \"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC11127599/\",\n", 345 | " \"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC11127585/\",\n", 346 | " \"https://www.ncbi.nlm.nih.gov/pmc/articles/PMC11127581/\"\n", 347 | " ]\n", 348 | " c3_template = \"\"\"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. Use less than 10 sentences maximum and keep the answer concise. \n", 349 | " \n", 350 | " {context} \n", 351 | " \n", 352 | " Use these to craft an answer to the question: {question}\"\"\"\n", 353 | " c3_prompt = ChatPromptTemplate.from_template(c3_template)\n", 354 | " \n", 355 | " docs = [WebBaseLoader(url).load() for url in urls] \n", 356 | " docs_list = [item for sublist in docs for item in sublist]\n", 357 | " text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n", 358 | " chunk_size=4000, chunk_overlap=400\n", 359 | " )\n", 360 | " doc_splits = text_splitter.split_documents(docs_list)\n", 361 | "\n", 362 | " embedding_function = MyEmbeddingFunction(client = bedrock_client,\n", 363 | " region_name=aws_region,\n", 364 | " model_id=embedding_model_id)\n", 365 | " # Add to vectorDB\n", 366 | " vectorstore = Chroma.from_documents(\n", 367 | " documents=doc_splits,\n", 368 | " embedding=embedding_function,\n", 369 | " collection_name=\"rag-chroma-titan-embed-text-v2-1\",\n", 370 | " )\n", 371 | " retriever = vectorstore.as_retriever(\n", 372 | " search_type=\"mmr\",\n", 373 | " search_kwargs={'k': 3, 'lambda_mult': 0.25})\n", 374 | " rag_chain = c3_prompt | llm | StrOutputParser()\n", 375 | " documents = retriever.invoke(state['question'])\n", 376 | " generation = rag_chain.invoke({\"context\": documents, \"question\": state['question']})\n", 377 | " #generation = rag_chain.invoke({\"context\": documents, \"question\": state['answer']}) # Use the rewritten question instead\n", 378 | " return {'answer': generation}\n", 379 | "\n", 380 | "\n", 381 | "#####\n", 382 | "# LLM node\n", 383 | "####\n", 384 | "def llm_node(state: MultiAgentState):\n", 385 | " model_ids = [model_id_mistral_large , model_id_l31]\n", 386 | " max_tokens = 2048\n", 387 | " temperature = 0.01\n", 388 | " top_p = 0.95\n", 389 | "\n", 390 | " conversation = [\n", 391 | " {\n", 392 | " \"role\": \"user\",\n", 393 | " #\"system\": \"You are a domain expert who can understand the intent of user query and answer question truthful and professionally. 
    "#####\n",
    "# LLM node\n",
    "####\n",
384 |     "def llm_node(state: MultiAgentState):\n",
385 |     "    model_ids = [model_id_mistral_large, model_id_l31]\n",
386 |     "    max_tokens = 2048\n",
387 |     "    temperature = 0.01\n",
388 |     "    top_p = 0.95\n",
389 |     "\n",
390 |     "    conversation = [\n",
391 |     "        {\n",
392 |     "            \"role\": \"user\",\n",
393 |     "            #\"system\": \"You are a domain expert who can understand the intent of the user query and answer questions truthfully and professionally. Please don't provide any unchecked information, and simply say that you don't know if you don't have enough info.\",\n",
394 |     "            \"content\": [{\"text\": state['question']}],\n",
395 |     "        }\n",
396 |     "    ]\n",
397 |     "    try:\n",
398 |     "        # Send the message to each model, using a basic inference configuration.\n",
399 |     "        responses = []\n",
400 |     "        for model_id in model_ids:\n",
401 |     "            response = bedrock_client.converse(\n",
402 |     "                modelId=model_id,\n",
403 |     "                messages=conversation,\n",
404 |     "                inferenceConfig={\"maxTokens\": max_tokens, \"temperature\": temperature, \"topP\": top_p},\n",
405 |     "            )\n",
406 |     "            \n",
407 |     "            # Extract and collect the response text.\n",
408 |     "            responses.append(response[\"output\"][\"message\"][\"content\"][0][\"text\"])\n",
409 |     "\n",
410 |     "        ###\n",
411 |     "        # Combine the answers to form a unified one\n",
412 |     "        ###\n",
413 |     "        c3_template = \"\"\"You are a domain expert and your goal is to merge the {{responses}} below, eliminating redundant elements, into a unified answer that captures the essence of all inputs while adhering to the following {{instruction}}.\n",
414 |     "        <instruction>\n",
415 |     "        Aggregate relevant information from the provided context. \n",
416 |     "        Eliminate redundancies to ensure a concise response. \n",
417 |     "        Maintain fidelity to the original content. \n",
418 |     "        Add additional relevant information and remove irrelevant information.\n",
419 |     "        </instruction>\n",
420 |     "        <responses>\n",
421 |     "        {responses}\n",
422 |     "        </responses>\n",
423 |     "        \"\"\"\n",
424 |     "        \n",
425 |     "        messages = [\n",
426 |     "            SystemMessage(content=c3_template.format(responses=responses)),  # fill the gathered answers into the merge prompt\n",
427 |     "            HumanMessage(content=state['question'])\n",
428 |     "        ]\n",
429 |     "\n",
430 |     "        return {'answer': llm_claude35.invoke(messages)}\n",
431 |     "    except Exception as e:  # Exception already covers ClientError\n",
432 |     "        print(f\"ERROR: Can't invoke '{model_id}'. Reason: {e}\")\n",
433 |     "\n",
434 |     "\n",
435 |     "####\n",
436 |     "# Human in the loop\n",
437 |     "###\n",
438 |     "\n",
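    "# The 'human' node pairs with interrupt_before=['human'] at compile time in\n",
    "# section 5: the run pauses before this node so a reviewer can add 'feedback'\n",
    "# to the state (e.g. via graph.update_state) before resuming.\n",
    "\n",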
\n", 442 | " In the output please provide the final answer to the customer without additional comments.\n", 443 | " Here's all the information you need.\n", 444 | " \n", 445 | " \n", 446 | " Question from customer: \n", 447 | " ----\n", 448 | " {question}\n", 449 | " ----\n", 450 | " Draft answer:\n", 451 | " ----\n", 452 | " {answer}\n", 453 | " ----\n", 454 | " Feedback: \n", 455 | " ----\n", 456 | " {feedback}\n", 457 | " ----\n", 458 | " '''\n", 459 | " print(state)\n", 460 | " messages = [\n", 461 | " SystemMessage(content=editor_prompt.format(question = state['question'], answer = state['answer'], feedback = state['feedback']))\n", 462 | " ]\n", 463 | " response = llm.invoke(messages)\n", 464 | " return {\"answer\": response.content}\n", 465 | "\n", 466 | "def editor_node(state: MultiAgentState):\n", 467 | " pass\n", 468 | " print(state)\n", 469 | " messages = [\n", 470 | " SystemMessage(content=editor_prompt.format(question = state['question'], answer = state['answer'], feedback = state['feedback']))\n", 471 | " ]\n", 472 | " response = llm.invoke(messages)\n", 473 | " return {\"answer\": response.content}\n", 474 | "\n", 475 | "#####\n", 476 | "# multi-agent collaboration node\n", 477 | "#####\n", 478 | "def blog_writer_node(state: MultiAgentState):\n", 479 | " blog_crew = blogCrew(topic=state['answer'], model_id=model_id_c35)\n", 480 | " result = blog_crew.run()\n", 481 | "\n", 482 | " ## Werite to a Markdown file\n", 483 | " if os.path.exists(markdown_filename):\n", 484 | " os.remove(markdown_filename)\n", 485 | " # Create the Markdown format and Save the Markdown text to a file\n", 486 | " markdown_text = f\"# Sample Text\\n\\n{result.raw}\\n\\n![Image]({temp_gen_image})\"\n", 487 | " with open(markdown_filename, \"w\") as file:\n", 488 | " file.write(markdown_text)\n", 489 | "\n", 490 | " return {\"answer\": result}" 491 | ] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "id": "671746d7-14f6-4c68-99cb-0f9b60573fcf", 496 | "metadata": {}, 497 | "source": [ 498 | "#### Additional functions" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "id": "a015677c-9fce-45b1-8b6b-ce8b4f5fa09c", 505 | "metadata": { 506 | "tags": [] 507 | }, 508 | "outputs": [], 509 | "source": [ 510 | "#####\n", 511 | "# Hallucination grader\n", 512 | "#####\n", 513 | "from langchain.callbacks.base import BaseCallbackHandler\n", 514 | "import random\n", 515 | "import base64\n", 516 | "\n", 517 | "class MyCustomHandler(BaseCallbackHandler):\n", 518 | " def on_llm_end(self, response, **kwargs):\n", 519 | " print(f\"Response: {response}\")\n", 520 | " \n", 521 | "def hallucination_grader_node(state:MultiAgentState):\n", 522 | " c3_template = \"\"\"You are a grader assessing whether an answer is grounded in supported by facts. \n", 523 | " Give a binary score 'pass' or 'fail' score to indicate whether the answer is grounded in supported by a \n", 524 | " set of facts in your best knowledge. 
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n",
525 |     "    \n",
526 |     "    Here is the answer: {answer}\"\"\"\n",
527 |     "    c3_prompt = ChatPromptTemplate.from_template(c3_template)\n",
528 |     "    \n",
529 |     "    # Grade with a different model, in this case Claude 3\n",
530 |     "    #hallucination_grader = prompt | llm_llama31 | JsonOutputParser() \n",
531 |     "    hallucination_grader = c3_prompt | llm_claude35 | JsonOutputParser()\n",
532 |     "    score = hallucination_grader.invoke({\"answer\": state['answer'], \"callbacks\": [MyCustomHandler()]})\n",
533 |     "    return {'answer': score}\n",
534 |     "\n",
535 |     "def hallucination_grader(state:MultiAgentState):\n",
536 |     "    c3_template = \"\"\"You are a grader assessing whether an answer is grounded in / supported by facts. \n",
537 |     "    Give a binary 'pass' or 'fail' score to indicate whether the answer is grounded in / supported by a \n",
538 |     "    set of facts in your best knowledge. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.\n",
539 |     "    \n",
540 |     "    Here is the answer: {answer}\"\"\"\n",
541 |     "    c3_prompt = ChatPromptTemplate.from_template(c3_template)\n",
542 |     "    \n",
543 |     "    # Grade with a different model, in this case Claude 3\n",
544 |     "    #hallucination_grader = prompt | llm_llama31 | JsonOutputParser() \n",
545 |     "    hallucination_grader = c3_prompt | llm_claude35 | JsonOutputParser()\n",
546 |     "    score = hallucination_grader.invoke({\"answer\": state['answer'], \"callbacks\": [MyCustomHandler()]})\n",
547 |     "    if \"pass\" in score['score'].lower():  # the prompt asks for 'pass' or 'fail'\n",
548 |     "        # The answer looks grounded, so end the run\n",
549 |     "        print(\n",
550 |     "            \"---DECISION: the answer does not seem to contain hallucination ---\"\n",
551 |     "        )\n",
552 |     "        return \"END\"\n",
553 |     "    else:\n",
554 |     "        # The answer may be hallucinated, so route it to human review\n",
555 |     "        print(\"---DECISION: the answer might contain hallucination, next off to human review ---\")\n",
556 |     "        return \"to_human\"\n",
557 |     "\n",
558 |     "\n",
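    "# Both graders above expect the model to return bare JSON such as {'score': 'pass'};\n",
    "# JsonOutputParser raises if the model adds a preamble, which is why the prompts\n",
    "# explicitly forbid explanations.\n",
    "\n",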
    "####\n",
    "# Extra function, not used as a node\n",
    "####\n",
563 |     "def decide_to_search(state:MultiAgentState):\n",
564 |     "    \"\"\"\n",
565 |     "    Determine whether the graded answer is ready for human review, or whether a web search is needed\n",
566 |     "    Args:\n",
567 |     "        state (dict): The current graph state\n",
568 |     "    Returns:\n",
569 |     "        str: Binary decision for next node to call\n",
570 |     "    \"\"\"\n",
571 |     "    l31_prompt = PromptTemplate(\n",
572 |     "        template=\"\"\" <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether\n",
573 |     "        an {answer} is grounded in / relevant to the {question}. Give a binary 'yes' or 'no' score to indicate\n",
574 |     "        whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a\n",
575 |     "        single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>\n",
576 |     "        Here is the answer:\n",
577 |     "        {answer}\n",
578 |     "        Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>\"\"\",\n",
579 |     "        input_variables=[\"question\", \"answer\"],\n",
580 |     "    )\n",
581 |     "    \n",
582 |     "    answer_grader = l31_prompt | llm_llama31 | JsonOutputParser()\n",
583 |     "    print(\"---ASSESS GRADED ANSWER AGAINST QUESTION---\")\n",
584 |     "    relevance = answer_grader.invoke({\"answer\": state[\"answer\"], \"question\": state[\"question\"]})\n",
585 |     "    print(relevance)\n",
586 |     "    if \"yes\" in relevance['score'].lower():\n",
587 |     "        # The answer is grounded in / relevant to the question\n",
588 |     "        print(\n",
589 |     "            \"---DECISION: the answer is relevant to the question so it's ready for human review ---\"\n",
590 |     "        )\n",
591 |     "        return \"to_human\"\n",
592 |     "    else:\n",
593 |     "        # The answer is not relevant, so fall back to web search\n",
594 |     "        print(\"---DECISION: the answer is NOT relevant to the question, so try web search ---\")\n",
595 |     "        return \"do_search\"\n",
596 |     "\n",
597 |     "#####\n",
598 |     "# text 2 image generation\n",
599 |     "####\n",
600 |     "def t2i_node2(state:MultiAgentState):\n",
601 |     "    negative_prompts = [\n",
602 |     "        \"poorly rendered\",\n",
603 |     "        \"poor background details\",\n",
604 |     "        \"poorly drawn objects\",\n",
605 |     "        \"poorly focused objects\",\n",
606 |     "        \"disfigured object features\",\n",
607 |     "        \"cartoon\",]\n",
608 |     "    body = json.dumps(\n",
609 |     "        {\n",
610 |     "            \"taskType\": \"TEXT_IMAGE\",\n",
611 |     "            \"textToImageParams\": {\n",
612 |     "                #\"text\":state['answer'].replace(\"{Rewritten Question}:\\n\\n\", \"\")[:510], # Required, Titan image gen v2 limits up to 512 token input\n",
613 |     "                \"text\":state['question'][:511],\n",
614 |     "                \"negativeText\": \"poorly rendered, disfigured object features\" # Optional; see negative_prompts above\n",
615 |     "            },\n",
616 |     "            \"imageGenerationConfig\": {\n",
617 |     "                \"numberOfImages\": 1, # Range: 1 to 5 \n",
618 |     "                \"quality\": 'premium', # Options: standard or premium\n",
619 |     "                \"height\": 1024, # Supported height list in the docs \n",
620 |     "                \"width\": 1024, # Supported width list in the docs\n",
621 |     "                \"cfgScale\": 6.5, # Range: 1.0 (exclusive) to 10.0\n",
622 |     "                \"seed\": random.randint(0, 2147483646) # Range: 0 to 2147483646\n",
623 |     "            }\n",
624 |     "        }\n",
625 |     "    )\n",
626 |     "\n",
627 |     "\n",
628 |     "    if os.path.exists(temp_gen_image):\n",
629 |     "        os.remove(temp_gen_image)\n",
630 |     "    response = bedrock_client.invoke_model(\n",
631 |     "        body=body, \n",
632 |     "        modelId=\"amazon.titan-image-generator-v2:0\",\n",
633 |     "        accept=\"application/json\", \n",
634 |     "        contentType=\"application/json\"\n",
635 |     "    )\n",
636 |     "    response_body = json.loads(response[\"body\"].read())\n",
637 |     "    with open(temp_gen_image, 'wb') as file:\n",
638 |     "        # Decode the base64 data and write it to the file\n",
639 |     "        file.write(base64.b64decode(response_body[\"images\"][0]))\n",
640 |     "    return {\"answer\": temp_gen_image}\n",
641 |     "\n",
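    "# t2i_node (wired into the graph in section 4) calls an external demo image\n",
    "# endpoint; t2i_node2 is a drop-in alternative that uses the Bedrock Titan\n",
    "# Image Generator v2 via invoke_model instead.\n",
    "\n",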
{url}\")\n", 651 | " pass\n", 652 | " with open(temp_gen_image, 'wb') as file:\n", 653 | " file.write(image_bytes)\n", 654 | " return {\"answer\": temp_gen_image}" 655 | ] 656 | }, 657 | { 658 | "cell_type": "markdown", 659 | "id": "c1a562f4-ef92-4a7a-859d-4f71b0852c62", 660 | "metadata": {}, 661 | "source": [ 662 | "### Pre-requisite: complete Bedrock agent association with knowledgeBase and lambda function to interact with pre-populated DynamoDB. OBtain the agent ID and Agent \n", 663 | "\n", 664 | "**Please note the next cell might take a few (> 10) minutes to complete**\n", 665 | "\n", 666 | "You might use AWS console (https://us-west-2.console.aws.amazon.com/aos/home?region=us-west-2#opensearch/get-started-serverless) or aws cli to check the status\n", 667 | "* %aws bedrock-agent list-agents\n", 668 | "* %aws bedrock-agent list-agent-aliases --agent-id \\\n", 669 | "* %aws bedrock-agent list-knowledge-bases\n", 670 | "* %aws bedrock-agent list-agent-knowledge-bases" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": null, 676 | "id": "e75a6c68-abc5-4d47-89fc-e4c84c15ade7", 677 | "metadata": { 678 | "tags": [] 679 | }, 680 | "outputs": [], 681 | "source": [ 682 | "#%run ./create-agent-with-knowledge-base-and-action-group.ipynb" 683 | ] 684 | }, 685 | { 686 | "cell_type": "markdown", 687 | "id": "016632ba-67e2-4737-a388-7c29e4d2642f", 688 | "metadata": {}, 689 | "source": [ 690 | "### Upon successful complettion of the Amazon Bedrock agent creation, define a node to invoke the agent. " 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": null, 696 | "id": "e2d8f2f5-be1a-418f-a97d-21446dd51834", 697 | "metadata": { 698 | "tags": [] 699 | }, 700 | "outputs": [], 701 | "source": [ 702 | "####\n", 703 | "# Bedrock agent integration\n", 704 | "####\n", 705 | "import uuid\n", 706 | "import logging\n", 707 | "from datetime import datetime\n", 708 | "%store -r agent_id\n", 709 | "%store -r alias_id\n", 710 | "\n", 711 | "def invoke_BR_agent(agent_id, alias_id, query, enable_trace=False, session_state=dict()):\n", 712 | " session_id = str(uuid.uuid1())\n", 713 | " end_session = False\n", 714 | " logger = logging.getLogger(__name__)\n", 715 | " \n", 716 | " # invoke the agent API\n", 717 | " agentResponse = bedrock_agent_runtime_client.invoke_agent(\n", 718 | " inputText=query,\n", 719 | " agentId=agent_id,\n", 720 | " agentAliasId=alias_id, \n", 721 | " sessionId=session_id,\n", 722 | " enableTrace=enable_trace, \n", 723 | " endSession= end_session,\n", 724 | " sessionState=session_state\n", 725 | " )\n", 726 | " \n", 727 | " if enable_trace:\n", 728 | " logger.info(pprint.pprint(agentResponse))\n", 729 | " \n", 730 | " event_stream = agentResponse['completion']\n", 731 | " try:\n", 732 | " for event in event_stream: \n", 733 | " if 'chunk' in event:\n", 734 | " data = event['chunk']['bytes']\n", 735 | " if enable_trace:\n", 736 | " logger.info(f\"Final answer ->\\n{data.decode('utf8')}\")\n", 737 | " agent_answer = data.decode('utf8')\n", 738 | " end_event_received = True\n", 739 | " return agent_answer\n", 740 | " # End event indicates that the request finished successfully\n", 741 | " elif 'trace' in event:\n", 742 | " if enable_trace:\n", 743 | " logger.info(json.dumps(event['trace'], indent=2))\n", 744 | " else:\n", 745 | " raise Exception(\"unexpected event.\", event)\n", 746 | " except Exception as e:\n", 747 | " raise Exception(\"unexpected event.\", e)\n", 748 | "\n", 749 | "def bedrock_agent_node(state:MultiAgentState):\n", 750 | " 
750 |     "    today = datetime.today().strftime('%b-%d-%Y')\n",
751 |     "    session_state = {\n",
752 |     "        \"promptSessionAttributes\": {\n",
753 |     "            \"name\": \"John Doe\",\n",
754 |     "            \"today\": today\n",
755 |     "        }\n",
756 |     "    }\n",
757 |     "    return {'answer': invoke_BR_agent(agent_id, alias_id, state[\"question\"], session_state=session_state)}"
758 |    ]
759 |   },
760 |   {
761 |    "cell_type": "markdown",
762 |    "id": "16a16723-fce4-4601-865c-ded1eab4d224",
763 |    "metadata": {},
764 |    "source": [
765 |     "## 4. Defining the Reasoning Flow with LangGraph Nodes and Edges\n",
766 |     "\n",
767 |     "Implement nodes representing key actions: document retrieval, document grading, web search, and answer generation. Define conditional edges for decision-making: route the question, decide on document relevance, and grade the generated answer. Set up the workflow graph with entry points, nodes, and edges to ensure a logical progression through the RAG agent's steps. LangGraph allows us to define a graph-based workflow for our RAG agent, integrating document retrieval, question routing, answer generation, and self-correction into an efficient pipeline.\n",
768 |     "\n",
769 |     "Key steps include:\n",
770 |     "\n",
771 |     "* Question rewrite: Rewrite the query for better intent classification\n",
772 |     "* Routing: Deciding whether the question should go to RAG, the LLMs, or a web search.\n",
773 |     "* Hallucination Grading: Ensuring the generated answer is grounded in the retrieved documents.\n",
774 |     "* Human in the loop: In case the answer falls below the desired quality, insert human feedback\n",
775 |     "\n",
776 |     "LangGraph lets us seamlessly integrate these steps into a modular, adaptable workflow, enhancing the agent's ability to handle diverse queries."
777 |    ]
778 |   },
779 |   {
780 |    "cell_type": "code",
781 |    "execution_count": null,
782 |    "id": "c876613d-860f-4955-bda7-973112beaec6",
783 |    "metadata": {
784 |     "tags": []
785 |    },
786 |    "outputs": [],
787 |    "source": [
788 |     "orch = StateGraph(MultiAgentState)\n",
789 |     "orch.add_node(\"rewrite\", rewrite_node)\n",
790 |     "orch.add_node(\"router\", router_node)\n",
791 |     "orch.add_node('search_expert', search_expert_node)\n",
792 |     "orch.add_node('healthcare_expert', rag_node)\n",
793 |     "orch.add_node('general_assistant', llm_node)\n",
794 |     "orch.add_node('text2image_generation', t2i_node)\n",
795 |     "orch.add_node('booking_assistant', bedrock_agent_node)\n",
796 |     "orch.add_node('blog_writer', blog_writer_node)\n",
797 |     "orch.add_node('human', human_feedback_node)\n",
798 |     "#orch.add_node('editor', editor_node)\n",
799 |     "\n",
800 |     "orch.add_conditional_edges(\n",
801 |     "    \"router\", \n",
802 |     "    route_question,\n",
803 |     "    {'VECTORSTORE': 'healthcare_expert', 'WEBSEARCH': 'search_expert', 'GENERAL': 'general_assistant', \n",
804 |     "     'TEXT2IMAGE': 'text2image_generation','BOOKING': 'booking_assistant', 'BLOGWRITER':'blog_writer'}\n",
805 |     ")\n",
806 |     "\n",
807 |     "orch.set_entry_point(\"rewrite\")\n",
808 |     "orch.add_edge('rewrite', 'router')\n",
809 |     "orch.add_conditional_edges(\n",
810 |     "    \"healthcare_expert\",\n",
811 |     "    decide_to_search,\n",
812 |     "    {\n",
813 |     "        \"to_human\": \"human\",\n",
814 |     "        \"do_search\": \"search_expert\",\n",
815 |     "    },\n",
816 |     ")\n",
817 |     "#orch.add_edge('search_expert', 'human')\n",
818 |     "orch.add_conditional_edges(\n",
819 |     "    \"search_expert\",\n",
820 |     "    decide_to_search,\n",
821 |     "    {\n",
822 |     "        \"to_human\": \"human\",\n",
823 |     "        \"do_search\": \"search_expert\",\n",
824 |     "    },\n",
825 |     ")\n",
826 |     "orch.add_edge('booking_assistant', END)\n",
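    "# general_assistant's draft is gated by the hallucination grader: 'END'\n",
    "# finishes the run, while 'to_human' routes the draft to human review.\n",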
"orch.add_conditional_edges(\n", 828 | " \"general_assistant\",\n", 829 | " hallucination_grader,\n", 830 | " {\n", 831 | " \"to_human\": \"human\",\n", 832 | " \"END\": END,\n", 833 | " },\n", 834 | ")\n", 835 | "orch.add_edge('human', END)\n", 836 | "#orch.add_edge('editor', END)\n", 837 | "orch.add_edge('blog_writer', 'text2image_generation')\n", 838 | "orch.add_edge('text2image_generation', END)" 839 | ] 840 | }, 841 | { 842 | "cell_type": "markdown", 843 | "id": "9ea1cbd4-18b1-4ff5-a608-3360ceae31fd", 844 | "metadata": {}, 845 | "source": [ 846 | "## 5. Display the orchestration flows\n", 847 | "\n", 848 | "The orchestration flows can be depicted using the following visual representation that illustrate the sequence of operations, the data transformations, and the control flow between the different modules or algorithms involved in the vision comprehension process. By providing a clear and concise visual representation of the orchestration, it becomes easier for developers, researchers, and stakeholders to understand the overall architecture, identify potential bottlenecks or optimization opportunities, and communicate the system's functionality and performance." 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": null, 854 | "id": "2c29c4f0-e832-4df5-89cf-b6d7b1907350", 855 | "metadata": { 856 | "scrolled": true, 857 | "tags": [] 858 | }, 859 | "outputs": [], 860 | "source": [ 861 | "from IPython.display import Image, display\n", 862 | "from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod #, NodeColors\n", 863 | "\n", 864 | "graph = orch.compile(checkpointer=memory, interrupt_before = ['human'])\n", 865 | "display(Image(graph.get_graph().draw_mermaid_png(\n", 866 | " curve_style=CurveStyle.LINEAR,\n", 867 | " #node_colors=NodeColors(start=\"#ffdfba\", end=\"#baffc9\", other=\"#fad7de\"),\n", 868 | " #node_styles=custom_node_style,\n", 869 | " wrap_label_n_words=9,\n", 870 | " output_file_path=None,\n", 871 | " draw_method=MermaidDrawMethod.API,\n", 872 | " background_color=\"white\",\n", 873 | " padding=20,\n", 874 | ")))" 875 | ] 876 | }, 877 | { 878 | "cell_type": "markdown", 879 | "id": "4cdc6137-3a75-43a9-bae0-64c6dc95ad2a", 880 | "metadata": {}, 881 | "source": [ 882 | "## 6. Execute this orchestration pipeline with query driven reasoning \n", 883 | "\n", 884 | "Executing agentic services with multi-agent capability on executing a pipeline with query-driven reasoning and reactions involves the development of a system that can autonomously perform tasks and make decisions based on the information it gathers and the queries it receives. This system would consist of multiple intelligent agents, each with its own set of capabilities and knowledge, working together to achieve a common goal. The agents would use query-driven reasoning to understand the user's intent and then react accordingly, executing the necessary steps in the pipeline to provide the desired outcome. This approach allows for a more dynamic and adaptive system that can handle a wide range of tasks and respond to changing conditions in real-time. The result is a powerful and flexible service that can assist users with a variety of needs, from information retrieval to complex problem-solving." 
885 |    ]
886 |   },
887 |   {
888 |    "cell_type": "code",
889 |    "execution_count": null,
890 |    "id": "a9c57a06-d95b-460b-aa7f-ab3a5a7c6ebc",
891 |    "metadata": {
892 |     "tags": []
893 |    },
894 |    "outputs": [],
895 |    "source": [
896 |     "from PIL import Image\n",
897 |     "\n",
898 |     "thread = {\"configurable\": {\"thread_id\": \"42\", \"recursion_limit\": 10}}\n",
899 |     "results = []\n",
900 |     "prompts =[\n",
901 |     "    \"Under what circumstances should a patient be screened for ectopic ACTH syndrome (EAS)?\", # Use native RAG, then human review if needed\n",
902 |     "    \"What could be the typical clinical symptoms of Blepharitis?\", # First try native RAG; if nothing is found, try web search, then human review if needed\n",
903 |     "    \"How many total medals did the US Olympic Team win in Paris 2024?\", # Use web search, then human review if needed\n",
904 |     "    \"Why was Steve Jobs considered a legend in the tech world?\", # Combine the answers from 2 LLMs, then human review if needed\n",
905 |     "    \"Generate a high res image of a colorful macaw resting on a tree, with vivid colors and attention to detail.\", # Use text-2-image generation \n",
906 |     "    \"Hi, I want to create a booking for 2 people, at 8pm on the 5th of May 2024.\", # Use Bedrock agent to interact with KnowledgeBase and DynamoDB\n",
907 |     "    \"Write a blog post about the 2024 uncrewed return of the Starliner space capsule amid safety concerns and helium leaks. Explain how NASA plans to safely return the two stranded astronauts. \\\n",
908 |     "    If possible, please include dates, figures and locations associated with the topic\" # Blog writing using CrewAI\n",
909 |     "    ]\n",
910 |     "\n",
911 |     "for prompt in prompts:\n",
912 |     "    for event in graph.stream({'question':prompt,}, thread):\n",
913 |     "        print(event)\n",
914 |     "        results.append(event)\n",
915 |     "    if os.path.exists(temp_gen_image):\n",
916 |     "        Image.open(temp_gen_image).show()\n",
917 |     "    print(\"\\n\\n---------------------------------------\\n\\n\")"
918 |    ]
919 |   },
920 |   {
921 |    "cell_type": "markdown",
922 |    "id": "d5bf6b07-6780-477f-b3d1-c80410e42469",
923 |    "metadata": {},
924 |    "source": [
925 |     "#### (Optional) Display the generated blog"
926 |    ]
927 |   },
928 |   {
929 |    "cell_type": "code",
930 |    "execution_count": null,
931 |    "id": "6b71dd8c-1657-4305-adb3-a5f97cd59edd",
932 |    "metadata": {},
933 |    "outputs": [],
934 |    "source": [
935 |     "from IPython.display import Markdown\n",
936 |     "with open(markdown_filename, 'r') as file:\n",
937 |     "    readme_content = file.read()\n",
938 |     "\n",
939 |     "# Display the contents as Markdown\n",
940 |     "display(Markdown(readme_content))"
941 |    ]
942 |   },
943 |   {
944 |    "cell_type": "markdown",
945 |    "id": "ca8e30c0-fba5-4359-a453-79d1fe3c92ae",
946 |    "metadata": {},
947 |    "source": [
948 |     "### Next Steps:\n",
949 |     "\n",
950 |     "1. Planning\n",
951 |     "2. Collaborative multi-agent reasoning\n",
952 |     "3. Memory for multi-round and personalized reasoning\n",
953 |     "4. While this simple search strategy shows a meaningful improvement in the success rate, it still struggles on long-horizon tasks due to the sparsity of environment rewards.\n",
954 |     "5. Combine a planning and reasoning agent with MCTS inference-time search and AI self-critique for self-supervised data collection, which we then use for RL-style training."
955 |    ]
956 |   },
957 |   {
958 |    "cell_type": "markdown",
959 |    "id": "ebdb6d0c-cb23-4462-986a-b68c8dacf88e",
960 |    "metadata": {},
961 |    "source": [
962 |     "## 7. Clean-up\n",
963 |     "Let's delete all the associated resources created to avoid unnecessary costs. 
Please change the markdown to code before executing the cells"
964 |    ]
965 |   },
966 |   {
967 |    "cell_type": "code",
968 |    "execution_count": null,
969 |    "id": "98627c16-c21f-429f-aef9-c486930f6dff",
970 |    "metadata": {},
971 |    "outputs": [],
972 |    "source": [
973 |     "from agent import clean_up_resources, delete_agent_roles_and_policies"
974 |    ]
975 |   },
976 |   {
977 |    "cell_type": "code",
978 |    "execution_count": null,
979 |    "id": "05e48afd-48d2-464e-9fac-9510f94a3241",
980 |    "metadata": {},
981 |    "outputs": [],
982 |    "source": [
983 |     "%store -r table_name\n",
984 |     "%store -r lambda_function\n",
985 |     "%store -r lambda_function_name\n",
986 |     "%store -r agent_action_group_response\n",
987 |     "%store -r agent_functions\n",
988 |     "%store -r alias_id\n",
989 |     "%store -r agent_id\n",
990 |     "%store -r agent_name\n",
991 |     "%store -r kb_id\n",
992 |     "\n",
993 |     "# Delete resources including Lambda, DynamoDB and the agent\n",
994 |     "clean_up_resources(\n",
995 |     "    table_name, lambda_function, lambda_function_name, agent_action_group_response, agent_functions, agent_id, kb_id, alias_id\n",
996 |     ")\n",
997 |     "\n",
998 |     "# Delete the agent roles and policies\n",
999 |     "delete_agent_roles_and_policies(agent_name)"
1000 |    ]
1001 |   },
1002 |   {
1003 |    "cell_type": "code",
1004 |    "execution_count": null,
1005 |    "id": "bf04e4ad-f2ab-4867-876e-1143139f75f4",
1006 |    "metadata": {},
1007 |    "outputs": [],
1008 |    "source": [
1009 |     "## __End"
1010 |    ]
1011 |   }
1012 |  ],
1013 |  "metadata": {
1014 |   "kernelspec": {
1015 |    "display_name": "medf",
1016 |    "language": "python",
1017 |    "name": "medf"
1018 |   },
1019 |   "language_info": {
1020 |    "codemirror_mode": {
1021 |     "name": "ipython",
1022 |     "version": 3
1023 |    },
1024 |    "file_extension": ".py",
1025 |    "mimetype": "text/x-python",
1026 |    "name": "python",
1027 |    "nbconvert_exporter": "python",
1028 |    "pygments_lexer": "ipython3",
1029 |    "version": "3.11.5"
1030 |   }
1031 |  },
1032 |  "nbformat": 4,
1033 |  "nbformat_minor": 5
1034 | }
--------------------------------------------------------------------------------