├── FedRLHF-problem.pdf
├── FedRLHF-problem.png
├── .gitignore
├── exp-LLM-IMDB
│   ├── start_multiple_clients.bash
│   ├── README.md
│   ├── config.py
│   ├── start_client.py
│   ├── plot_combined_performance.py
│   ├── environment.yml
│   ├── server.py
│   ├── centralized_training.py
│   └── client.py
└── README.md
/FedRLHF-problem.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flint-xf-fan/Federated-RLHF/HEAD/FedRLHF-problem.pdf --------------------------------------------------------------------------------
/FedRLHF-problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flint-xf-fan/Federated-RLHF/HEAD/FedRLHF-problem.png --------------------------------------------------------------------------------
/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | datasets 3 | __pycache__ 4 | */__pycache__ 5 | */trained_models 6 | exp-LLM-IMDB/evaluation_logs/ 7 | exp-LLM-IMDB/training_logs/ 8 | 9 | --------------------------------------------------------------------------------
/exp-LLM-IMDB/start_multiple_clients.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Read num_clients from config.py using Python 4 | NUM_CLIENTS=$(python -c 'import config; print(config.NUM_CLIENTS)') 5 | 6 | # Loop to start multiple clients 7 | for (( i=0; i<$NUM_CLIENTS; i++ )) 8 | do 9 | echo "Starting client $i" 10 | python start_client.py $i & 11 | done 12 | 13 | # Wait for all background processes to finish 14 | wait 15 | 16 | echo "All clients have finished" --------------------------------------------------------------------------------
/exp-LLM-IMDB/README.md: -------------------------------------------------------------------------------- 1 | # FedRLHF 2 | ## Task: IMDB 3 | Fine-tune GPT-2 with PPO on the IMDb dataset so that generated review continuations are steered towards positive sentiment; each client mixes a distilbert-imdb sentiment reward with an intrinsic language-model reward using its own LAMBDA_LM weight. 4 | ## Description 5 | `server.py` runs the Flower server with a custom FedAvg strategy, `client.py` implements the FedRLHF client that trains locally with PPO, `centralized_training.py` provides a centralized RLHF baseline, and `plot_combined_performance.py` compares federated and centralized runs. 6 | ## TODO 7 | 8 | ## Set up 9 | ``` 10 | conda env create -f environment.yml 11 | conda activate fedrlhf 12 | ``` 13 | 14 | ## Running Experiments 15 | First, set `NUM_CLIENTS` (and, if needed, `NUM_ROUNDS`) in `config.py`; both `server.py` and `start_multiple_clients.bash` read their values from there, so the server and the clients stay consistent. 16 | 17 | Then, start the FedRLHF server by running 18 | 19 | ``` 20 | python server.py 21 | ``` 22 | 23 | In another terminal, run the bash script to simulate multiple clients connecting to the server via gRPC (check that the port number is not already in use; the default is 8080): 24 | 25 | 26 | ``` 27 | bash start_multiple_clients.bash 28 | ``` --------------------------------------------------------------------------------
/exp-LLM-IMDB/config.py: -------------------------------------------------------------------------------- 1 | # config.py 2 | 3 | # Global configuration variables 4 | 5 | # Number of federated learning rounds 6 | NUM_ROUNDS = 5 # Adjust as needed 7 | 8 | # Number of clients participating in the federated learning 9 | NUM_CLIENTS = 5 # Adjust as needed 10 | 11 | # LAMBDA_LM = 0.5 12 | 13 | # Generate per-client LAMBDA_LM weights dynamically, evenly spaced over [0.1, 0.9] (e.g. 0.1, 0.3, 0.5, 0.7, 0.9 for NUM_CLIENTS = 5) 14 | LAMBDA_LMs = {i: 0.1 + (0.8 * i / (NUM_CLIENTS - 1)) for i in range(NUM_CLIENTS)} 15 | 16 | DATASET_DIVISION = 1 # Adjust as needed 17 | 18 | # Flower server mode -- "server" or "simulation" 19 | FLOWER_SIM_MODE = "server" 20 | # FLOWER_SIM_MODE = "simulation" 21 | 22 | SEED = 42 23 | 24 | # Config for the trainer 25 | TRAINER_CONFIG = { 26 | "BATCH_SIZE": 16, 27 | "MINI_BATCH_SIZE": 16, 28 | "LEARNING_RATE": 1e-5, 29 | "num_epochs": 5, 30 | "num_warmup_steps": 100, 31 | 
"num_training_steps": 1000, 32 | } 33 | 34 | VERBOSE = False 35 | -------------------------------------------------------------------------------- /exp-LLM-IMDB/start_client.py: -------------------------------------------------------------------------------- 1 | # start_client.py 2 | 3 | import flwr as fl 4 | from client import FedRLHFClient 5 | import sys 6 | from config import NUM_CLIENTS, NUM_ROUNDS, LAMBDA_LMs 7 | 8 | def main(client_id: int, num_clients: int, num_rounds: int, lambda_lm: float) -> None: 9 | # Initialize your client with the new arguments 10 | client = FedRLHFClient(client_id, num_clients, num_rounds, lambda_lm) 11 | 12 | # Start Flower client 13 | fl.client.start_client( 14 | server_address="127.0.0.1:8080", 15 | client=client, 16 | grpc_max_message_length=1024 * 1024 * 1024 # 1 GB 17 | ) 18 | 19 | if __name__ == "__main__": 20 | if len(sys.argv) != 2: 21 | print("Usage: python start_client.py ") 22 | sys.exit(1) 23 | client_id = int(sys.argv[1]) 24 | client_lambda_lm = LAMBDA_LMs[client_id] 25 | main(client_id, NUM_CLIENTS, NUM_ROUNDS, client_lambda_lm) 26 | -------------------------------------------------------------------------------- /exp-LLM-IMDB/plot_combined_performance.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import json 3 | import os 4 | import numpy as np 5 | from config import NUM_CLIENTS 6 | import seaborn as sns 7 | 8 | def load_json(filename): 9 | with open(filename, 'r') as f: 10 | return json.load(f) 11 | 12 | def interpolate_data(x, y, new_x): 13 | return np.interp(new_x, x, y) 14 | 15 | def plot_combined_performance(): 16 | sns.set_theme(style="darkgrid") 17 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8), sharex=True) 18 | 19 | # Load centralized data 20 | cent_data = load_json('metrics/metrics_centralized.json') 21 | cent_samples = np.array(cent_data['total_samples']) 22 | cent_rewards = np.array(cent_data['avg_rewards']) 23 | cent_losses = np.array(cent_data['avg_losses']) 24 | 25 | # Load and aggregate client data 26 | client_data = [load_json(f'metrics/metrics_client_{i}.json') for i in range(NUM_CLIENTS)] 27 | fed_samples = np.array(client_data[0]['total_samples']) # Assuming all clients have same sample points 28 | fed_rewards = np.array([data['avg_rewards'] for data in client_data]) 29 | fed_losses = np.array([data['losses'] for data in client_data]) 30 | 31 | # Calculate mean and std of federated rewards and losses 32 | fed_mean_rewards = np.mean(fed_rewards, axis=0) 33 | fed_std_rewards = np.std(fed_rewards, axis=0) 34 | fed_mean_losses = np.mean(fed_losses, axis=0) 35 | fed_std_losses = np.std(fed_losses, axis=0) 36 | 37 | # Interpolate centralized data to match federated sample points 38 | cent_rewards_interp = interpolate_data(cent_samples, cent_rewards, fed_samples) 39 | cent_losses_interp = interpolate_data(cent_samples, cent_losses, fed_samples) 40 | 41 | # Plotting Rewards 42 | ax1.plot(fed_samples, cent_rewards_interp, label='Centralized RLHF', color='blue') 43 | ax1.plot(fed_samples, fed_mean_rewards, label=f'FedRLHF (K={NUM_CLIENTS})', color='red') 44 | ax1.fill_between(fed_samples, fed_mean_rewards - fed_std_rewards, fed_mean_rewards + fed_std_rewards, alpha=0.3, color='red') 45 | ax1.set_ylabel('Average Reward',fontsize=16) 46 | # ax1.set_title('Rewards Comparison: Centralized vs Federated',fontsize=14) 47 | ax1.legend(fontsize=16) 48 | ax1.grid(True) 49 | 50 | # Plotting Losses 51 | ax2.plot(fed_samples, cent_losses_interp, 
label='Centralized RLHF', color='blue') 52 | ax2.plot(fed_samples, fed_mean_losses, label=f'FedRLHF (K={NUM_CLIENTS})', color='red') 53 | ax2.fill_between(fed_samples, fed_mean_losses - fed_std_losses, fed_mean_losses + fed_std_losses, alpha=0.3, color='red') 54 | ax2.set_xlabel('Total Samples',fontsize=16) 55 | ax2.set_ylabel('Average Loss',fontsize=16) 56 | # ax2.set_title('Losses Comparison: Centralized vs Federated',fontsize=14) 57 | ax2.legend(fontsize=16) 58 | ax2.grid(True) 59 | 60 | # Adjust x-axis to start from 0 61 | ax1.set_xlim(0, max(fed_samples)) 62 | ax2.set_xlim(0, max(fed_samples)) 63 | 64 | plt.tight_layout() 65 | 66 | # Save the plot 67 | os.makedirs("training_logs", exist_ok=True) 68 | plt.savefig(f"training_logs/performance_comparison_k{NUM_CLIENTS}.pdf", dpi=300, bbox_inches='tight') 69 | plt.close() 70 | 71 | print(f"Combined performance plot saved as 'training_logs/performance_comparison_k{NUM_CLIENTS}.pdf'") 72 | 73 | # Additional analysis 74 | print("\nPerformance Analysis:") 75 | print(f"Centralized final reward: {cent_rewards[-1]:.4f}") 76 | print(f"Federated final mean reward: {fed_mean_rewards[-1]:.4f}") 77 | print(f"Federated final reward std: {fed_std_rewards[-1]:.4f}") 78 | print(f"Centralized final loss: {cent_losses[-1]:.4f}") 79 | print(f"Federated final mean loss: {fed_mean_losses[-1]:.4f}") 80 | print(f"Federated final loss std: {fed_std_losses[-1]:.4f}") 81 | 82 | if __name__ == "__main__": 83 | plot_combined_performance() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FedRLHF: A Convergence-Guaranteed Framework for Privacy-Preserving and Personalized RLHF 2 | 3 | [![AAMAS 2025](https://img.shields.io/badge/AAMAS-2025-blue)](https://aamas2025.com) 4 | [![GitHub Repo stars](https://img.shields.io/github/stars/flint-xf-fan/Federated-RLHF?style=social)](https://github.com/flint-xf-fan/Federated-RLHF) 5 | 6 | 7 | [![FedRLHF](FedRLHF-problem.png)](FedRLHF-problem.pdf) 8 | 9 | ### (updating...) Code release progress 10 | ⚠️ we are still testing the setup & configuration script. we will update this page once everything is ready ⚠️ (expected to be completed around the conference dates ie. in late May) 11 | 12 | - [x] [IMDB experiments] code setup and scripts uploaded 13 | - [ ] [IMDB experiments] code clean up, test and detailed instructions 14 | - [ ] [MovieLens experiments] code setup and scripts upload 15 | - [ ] [MovieLens experiments] code clean up, test and detailed instructions 16 | 17 | ## Overview 18 | This repository contains the official implementation of **FedRLHF: A Convergence-Guaranteed Federated Framework for Privacy-Preserving and Personalized Reinforcement Learning with Human Feedback**, as presented at **AAMAS 2025**. 19 | 20 | FedRLHF combines federated learning principles with reinforcement learning from human feedback (RLHF) to provide: 21 | 22 | - **Privacy-Preserving Training**: Securely train models without sharing raw data between clients. 23 | - **Personalized Reinforcement Learning**: Incorporate human feedback to personalize the policy. 24 | - **Convergence Guarantees**: Rigorous proofs for convergence under federated settings. 25 | 26 | --- 27 | 28 | ## Features 29 | - **Two Benchmark Tasks**: 30 | - IMDb (Sentiment Analysis + Reward Modeling) 31 | - MovieLens (Recommendation Systems with Federated Policies) 32 | - **Components**: 33 | - Federated Server and Client implementations. 
34 | - Reward Modeling and Policy Optimization. 35 | - Tools for Visualization and Performance Analysis. 36 | - **Live Demo**: Showcase of personalization using pre-trained models (planned). 37 | 38 | --- 39 | 40 | ## Repository Structure 41 | 42 | ```plaintext 43 | FedRLHF/ 44 | ├── IMDb/ # IMDb-based federated RLHF task 45 | │ ├── centralized_training.py 46 | │ ├── server.py 47 | │ ├── client.py 48 | │ ├── config.py 49 | │ ├── plot_combined_performance.py 50 | │ ├── visualize_rewards_trends.py 51 | │ ├── req.txt # Dependencies 52 | │ └── start_multiple_clients.bash 53 | ├── MovieLens/ # MovieLens-based federated RLHF task 54 | │ ├── fed_rlhf/ # Core FedRLHF Implementation 55 | │ │ ├── server.py 56 | │ │ └── client.py 57 | │ ├── utils/ # Utilities for metrics and visualization 58 | │ ├── models/ # Reward and Base Models 59 | │ ├── data/ # Dataset loading 60 | │ ├── environment.yml # Conda environment setup 61 | │ └── main.py # Entry point for MovieLens experiments 62 | └── README.md 63 | ``` 64 | 65 | --- 66 | 67 | ## Getting Started 68 | 69 | ### Prerequisites 70 | 71 | Ensure you have the following installed: 72 | - Python >= 3.8 73 | - Conda or Virtualenv (for environment setup) 74 | - Flower 75 | - (updating..) 76 | 77 | - 78 | ### Setup 79 | 80 | 1. **Clone the Repository**: 81 | ```bash 82 | git clone https://github.com/flint-xf-fan/Federated-RLHF.git 83 | cd Federated-RLHF 84 | ``` 85 | 86 | 2. **Install Dependencies**: 87 | For IMDb: 88 | ```bash 89 | pip install -r IMDb/req.txt 90 | ``` 91 | For MovieLens: 92 | ```bash 93 | conda env create -f MovieLens/environment.yml 94 | conda activate fedrlhf 95 | ``` 96 | 97 | ### Running Experiments 98 | 99 | #### IMDb 100 | TODO 101 | 102 | #### MovieLens 103 | TODO 104 | 105 | --- 106 | 107 | ## Citation 108 | If you use this code in your research, please cite the following paper (preprint): 109 | 110 | ```bibtex 111 | @article{fan2024fedrlhf, 112 | title={FedRLHF: A Convergence-Guaranteed Federated Framework for Privacy-Preserving and Personalized RLHF}, 113 | author={Fan, Flint Xiaofeng and Tan, Cheston and Ong, Yew-Soon and Wattenhofer, Roger and Ooi, Wei-Tsang}, 114 | journal={arXiv preprint arXiv:2412.15538}, 115 | year={2024} 116 | } 117 | ``` 118 | 119 | TODO: replace the preprint with aamas version. 120 | 121 | --- 122 | 123 | ## License 124 | See the aamas license documentation. 125 | 126 | --- 127 | 128 | ## Contact 129 | For questions or collaboration inquiries, please contact: 130 | - **Name**: Flint 131 | - **Email**: fxf@u.nus.edu 132 | 133 | --- 134 | 135 | Enjoy using FedRLHF! 
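### Appendix: provisional IMDb quick start

Until the *Running Experiments* sections above are filled in, the following is a minimal sketch of the IMDb workflow, pieced together from the scripts under `exp-LLM-IMDB/` (the client count, round count, and the 8080 port are the defaults in `config.py` and `server.py`; adjust them there first if needed):

```bash
cd exp-LLM-IMDB
conda env create -f environment.yml && conda activate fedrlhf

# optional: inspect the per-client personalization weights derived in config.py
python -c 'import config; print(config.LAMBDA_LMs)'

# terminal 1: start the Flower server (listens on 0.0.0.0:8080 for NUM_ROUNDS rounds)
python server.py

# terminal 2: launch NUM_CLIENTS clients in the background and wait for them to finish
bash start_multiple_clients.bash

# afterwards: run the centralized baseline and plot the federated vs. centralized comparison
python centralized_training.py
python plot_combined_performance.py
```

Metrics and figures are written to `metrics/` and `training_logs/` inside `exp-LLM-IMDB/`.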
136 | -------------------------------------------------------------------------------- /exp-LLM-IMDB/environment.yml: -------------------------------------------------------------------------------- 1 | name: fedrlhf 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - bottleneck=1.3.7=py39ha9d4c09_0 11 | - brotli=1.0.9=h5eee18b_8 12 | - brotli-bin=1.0.9=h5eee18b_8 13 | - brotli-python=1.0.9=py39h6a678d5_8 14 | - bzip2=1.0.8=h5eee18b_6 15 | - ca-certificates=2024.7.2=h06a4308_0 16 | - certifi=2024.7.4=py39h06a4308_0 17 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 18 | - cuda-cudart=11.8.89=0 19 | - cuda-cupti=11.8.87=0 20 | - cuda-libraries=11.8.0=0 21 | - cuda-nvrtc=11.8.89=0 22 | - cuda-nvtx=11.8.86=0 23 | - cuda-runtime=11.8.0=0 24 | - cuda-version=12.6=3 25 | - ffmpeg=4.3=hf484d3e_0 26 | - filelock=3.13.1=py39h06a4308_0 27 | - freetype=2.12.1=h4a9f257_0 28 | - gmp=6.2.1=h295c915_3 29 | - gmpy2=2.1.2=py39heeb90bb_0 30 | - gnutls=3.6.15=he1e5248_0 31 | - idna=3.7=py39h06a4308_0 32 | - importlib_resources=6.4.0=py39h06a4308_0 33 | - intel-openmp=2023.1.0=hdb19cb5_46306 34 | - jinja2=3.1.4=py39h06a4308_0 35 | - jpeg=9e=h5eee18b_3 36 | - lame=3.100=h7b6447c_0 37 | - lcms2=2.12=h3be6417_0 38 | - ld_impl_linux-64=2.38=h1181459_1 39 | - lerc=3.0=h295c915_0 40 | - libbrotlicommon=1.0.9=h5eee18b_8 41 | - libbrotlidec=1.0.9=h5eee18b_8 42 | - libbrotlienc=1.0.9=h5eee18b_8 43 | - libcublas=11.11.3.6=0 44 | - libcufft=10.9.0.58=0 45 | - libcufile=1.11.1.6=0 46 | - libcurand=10.3.7.68=0 47 | - libcusolver=11.4.1.48=0 48 | - libcusparse=11.7.5.86=0 49 | - libdeflate=1.17=h5eee18b_1 50 | - libffi=3.4.4=h6a678d5_1 51 | - libgcc-ng=11.2.0=h1234567_1 52 | - libgomp=11.2.0=h1234567_1 53 | - libiconv=1.16=h5eee18b_3 54 | - libidn2=2.3.4=h5eee18b_0 55 | - libjpeg-turbo=2.0.0=h9bf148f_0 56 | - libnpp=11.8.0.86=0 57 | - libnvjpeg=11.9.0.86=0 58 | - libpng=1.6.39=h5eee18b_0 59 | - libstdcxx-ng=11.2.0=h1234567_1 60 | - libtasn1=4.19.0=h5eee18b_0 61 | - libtiff=4.5.1=h6a678d5_0 62 | - libunistring=0.9.10=h27cfd23_0 63 | - libwebp-base=1.3.2=h5eee18b_0 64 | - llvm-openmp=14.0.6=h9e868ea_0 65 | - lz4-c=1.9.4=h6a678d5_1 66 | - markupsafe=2.1.3=py39h5eee18b_0 67 | - matplotlib-base=3.9.2=py39hbfdbfaf_0 68 | - mkl=2023.1.0=h213fc3f_46344 69 | - mkl-service=2.4.0=py39h5eee18b_1 70 | - mkl_fft=1.3.8=py39h5eee18b_0 71 | - mkl_random=1.2.4=py39hdb19cb5_0 72 | - mpc=1.1.0=h10f8cd9_1 73 | - mpfr=4.0.2=hb69a4c5_1 74 | - mpmath=1.3.0=py39h06a4308_0 75 | - ncurses=6.4=h6a678d5_0 76 | - nettle=3.7.3=hbbd107a_1 77 | - networkx=3.2.1=py39h06a4308_0 78 | - numexpr=2.8.7=py39h85018f9_0 79 | - numpy=1.26.4=py39h5f9d8c6_0 80 | - numpy-base=1.26.4=py39hb5e798b_0 81 | - openh264=2.1.1=h4ff587b_0 82 | - openjpeg=2.5.2=he7f1fd0_0 83 | - openssl=3.0.14=h5eee18b_0 84 | - packaging=24.1=py39h06a4308_0 85 | - pandas=2.2.2=py39h6a678d5_0 86 | - pillow=10.4.0=py39h5eee18b_0 87 | - pip=24.2=py39h06a4308_0 88 | - pysocks=1.7.1=py39h06a4308_0 89 | - python=3.9.19=h955ad1f_1 90 | - python-dateutil=2.9.0post0=py39h06a4308_2 91 | - python-tzdata=2023.3=pyhd3eb1b0_0 92 | - pytorch=2.4.0=py3.9_cuda11.8_cudnn9.1.0_0 93 | - pytorch-cuda=11.8=h7e8668a_5 94 | - pytorch-mutex=1.0=cuda 95 | - pytz=2024.1=py39h06a4308_0 96 | - pyyaml=6.0.1=py39h5eee18b_0 97 | - readline=8.2=h5eee18b_0 98 | - requests=2.32.3=py39h06a4308_0 99 | - seaborn=0.13.2=py39h06a4308_0 100 | - setuptools=72.1.0=py39h06a4308_0 101 | - six=1.16.0=pyhd3eb1b0_1 102 | - 
sqlite=3.45.3=h5eee18b_0 103 | - sympy=1.13.2=py39h06a4308_0 104 | - tbb=2021.8.0=hdb19cb5_0 105 | - tk=8.6.14=h39e8969_0 106 | - torchaudio=2.4.0=py39_cu118 107 | - torchtriton=3.0.0=py39 108 | - torchvision=0.19.0=py39_cu118 109 | - typing_extensions=4.11.0=py39h06a4308_0 110 | - unicodedata2=15.1.0=py39h5eee18b_0 111 | - urllib3=2.2.2=py39h06a4308_0 112 | - wheel=0.43.0=py39h06a4308_0 113 | - xz=5.4.6=h5eee18b_1 114 | - yaml=0.2.5=h7b6447c_0 115 | - zlib=1.2.13=h5eee18b_1 116 | - zstd=1.5.5=hc292b87_2 117 | - pip: 118 | - accelerate==0.33.0 119 | - aiohappyeyeballs==2.4.0 120 | - aiohttp==3.10.5 121 | - aiosignal==1.3.1 122 | - async-timeout==4.0.3 123 | - attrs==24.2.0 124 | - cffi==1.17.0 125 | - click==8.1.7 126 | - colorama==0.4.6 127 | - contourpy==1.3.0 128 | - cryptography==42.0.8 129 | - cycler==0.12.1 130 | - datasets==2.21.0 131 | - dill==0.3.8 132 | - docker-pycreds==0.4.0 133 | - docstring-parser==0.16 134 | - eval-type-backport==0.2.0 135 | - flwr==1.11.0 136 | - fonttools==4.53.1 137 | - frozenlist==1.4.1 138 | - fsspec==2024.6.1 139 | - gitdb==4.0.11 140 | - gitpython==3.1.43 141 | - grpcio==1.66.1 142 | - huggingface-hub==0.24.6 143 | - importlib-resources==6.4.4 144 | - iterators==0.0.2 145 | - joblib==1.4.2 146 | - kiwisolver==1.4.5 147 | - markdown-it-py==3.0.0 148 | - mdurl==0.1.2 149 | - multidict==6.0.5 150 | - multiprocess==0.70.16 151 | - pathspec==0.12.1 152 | - platformdirs==4.2.2 153 | - protobuf==4.25.4 154 | - psutil==6.0.0 155 | - pyarrow==17.0.0 156 | - pycparser==2.22 157 | - pycryptodome==3.20.0 158 | - pygments==2.18.0 159 | - pyparsing==3.1.4 160 | - regex==2024.7.24 161 | - rich==13.8.0 162 | - safetensors==0.4.4 163 | - scikit-learn==1.5.1 164 | - scikit-surprise==1.1.4 165 | - scipy==1.13.1 166 | - sentry-sdk==2.13.0 167 | - setproctitle==1.3.3 168 | - shellingham==1.5.4 169 | - shtab==1.7.1 170 | - smmap==5.0.1 171 | - threadpoolctl==3.5.0 172 | - tokenizers==0.19.1 173 | - tomli==2.0.1 174 | - tomli-w==1.0.0 175 | - tqdm==4.66.5 176 | - transformers==4.44.2 177 | - trl==0.10.1 178 | - typer==0.9.4 179 | - tyro==0.8.10 180 | - tzdata==2024.1 181 | - wandb==0.17.8 182 | - xxhash==3.5.0 183 | - yarl==1.9.7 184 | - zipp==3.20.1 185 | prefix: /home/flint/anaconda3/envs/fedrlhf 186 | -------------------------------------------------------------------------------- /exp-LLM-IMDB/server.py: -------------------------------------------------------------------------------- 1 | # server.py 2 | 3 | import os 4 | import flwr as fl 5 | from flwr.server.strategy import FedAvg 6 | from typing import List, Tuple, Dict, Optional 7 | from flwr.common import Parameters, FitRes, Status, Code 8 | from flwr.server.client_proxy import ClientProxy 9 | from collections import defaultdict 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import json 13 | import seaborn as sns 14 | sns.set_theme(style="darkgrid") 15 | 16 | from config import NUM_CLIENTS, NUM_ROUNDS, FLOWER_SIM_MODE, LAMBDA_LMs 17 | 18 | # Custom FedAvg strategy with overridden aggregate_fit 19 | class CustomFedAvg(FedAvg): 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.client_metrics = defaultdict(list) 23 | self.global_metrics = defaultdict(list) 24 | 25 | def aggregate_fit( 26 | self, 27 | server_round: int, 28 | results: List[Tuple[ClientProxy, FitRes]], 29 | failures: List[BaseException], 30 | ) -> Tuple[Optional[Parameters], Dict[str, float]]: 31 | # Use the superclass method to aggregate parameters 32 | aggregated_params, aggregated_metrics = 
super().aggregate_fit(server_round, results, failures) 33 | 34 | if aggregated_params is not None: 35 | # Process client-specific metrics 36 | total_steps = 0 37 | total_samples = 0 38 | rewards = [] 39 | losses = [] 40 | for _, fit_res in results: 41 | client_id = fit_res.metrics["client_id"] 42 | rewards.append(fit_res.metrics["avg_reward"]) 43 | losses.append(fit_res.metrics["avg_loss"]) 44 | total_steps += fit_res.metrics["total_steps"] 45 | total_samples += fit_res.metrics["total_samples"] 46 | 47 | self.client_metrics[client_id].append({ 48 | "round": server_round, 49 | "avg_reward": fit_res.metrics["avg_reward"], 50 | "avg_loss": fit_res.metrics["avg_loss"], 51 | "total_steps": fit_res.metrics["total_steps"], 52 | "total_samples": fit_res.metrics["total_samples"], 53 | }) 54 | 55 | # Calculate simple averages 56 | avg_reward = sum(rewards) / len(rewards) if rewards else 0.0 57 | avg_loss = sum(losses) / len(losses) if losses else 0.0 58 | 59 | # Save metrics 60 | self.global_metrics["rounds"].append(server_round) 61 | self.global_metrics["avg_reward"].append(avg_reward) 62 | self.global_metrics["avg_loss"].append(avg_loss) 63 | self.global_metrics["total_steps"].append(total_steps) 64 | self.global_metrics["total_samples"].append(total_samples) 65 | 66 | print(f"Round {server_round} - Global Metrics:") 67 | print(f" Average Reward: {avg_reward:.4f}") 68 | print(f" Average Loss: {avg_loss:.4f}") 69 | print(f" Total Steps: {total_steps}") 70 | print(f" Total Samples: {total_samples}") 71 | 72 | # Generate and save the visualization 73 | self.visualize_metrics() 74 | 75 | return aggregated_params, aggregated_metrics 76 | 77 | def visualize_metrics(self): 78 | rounds = range(1, len(self.global_metrics["avg_reward"]) + 1) 79 | fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10)) # Increased figure height 80 | 81 | # Rewards subplot 82 | for client_id, metrics in self.client_metrics.items(): 83 | client_rounds = [m['round'] for m in metrics] 84 | rewards = [m['avg_reward'] for m in metrics] 85 | ax1.plot(client_rounds, rewards, marker='o', label=f'Client {client_id} Reward') 86 | ax1.plot(rounds, self.global_metrics["avg_reward"], marker='s', linewidth=2, linestyle='--', label='Global Avg Reward') 87 | # ax1.set_title("Per-Client and Global Average Rewards") 88 | ax1.set_xlabel("Round",fontsize=16) 89 | ax1.set_ylabel("Reward",fontsize=16) 90 | 91 | # Losses subplot 92 | for client_id, metrics in self.client_metrics.items(): 93 | client_rounds = [m['round'] for m in metrics] 94 | losses = [m['avg_loss'] for m in metrics] 95 | ax2.plot(client_rounds, losses, marker='o', label=f'Client {client_id} Loss') 96 | ax2.plot(rounds, self.global_metrics["avg_loss"], marker='s', linewidth=2, linestyle='--', label='Global Avg Loss') 97 | # ax2.set_title("Per-Client and Global Average Losses") 98 | ax2.set_xlabel("Round",fontsize=16) 99 | ax2.set_ylabel("Loss",fontsize=16) 100 | 101 | # Adjust layout and add a single legend 102 | plt.tight_layout() 103 | handles, labels = ax1.get_legend_handles_labels() 104 | fig.legend(handles, labels, loc='lower center', ncol=3, bbox_to_anchor=(0.52, -0.05), fontsize=16, handlelength=2, columnspacing=1) 105 | 106 | # Adjust subplot spacing to make room for the legend 107 | plt.subplots_adjust(bottom=0.1) 108 | 109 | os.makedirs("training_logs", exist_ok=True) 110 | plt.savefig("training_logs/global_performance.pdf",dpi=300, bbox_inches='tight') 111 | plt.close() 112 | print("Visualization saved as 'training_logs/global_performance.png'") 113 | 114 | # Save metrics 
to a single JSON file for combined plotting 115 | metrics_data = { 116 | "rounds": self.global_metrics["rounds"], 117 | "avg_rewards": self.global_metrics["avg_reward"], 118 | "avg_losses": self.global_metrics["avg_loss"], 119 | "total_steps": self.global_metrics["total_steps"], 120 | "total_samples": self.global_metrics["total_samples"], 121 | } 122 | os.makedirs("metrics", exist_ok=True) 123 | with open(f"metrics/metrics_federated_k{NUM_CLIENTS}.json", "w") as f: 124 | json.dump(metrics_data, f) 125 | print(f"Federated metrics saved as 'metrics/metrics_federated_k{NUM_CLIENTS}.json'") 126 | 127 | def plot_performance_vs_samples(self): 128 | """Plot global average reward vs total samples.""" 129 | plt.figure(figsize=(10, 6)) 130 | total_samples = self.global_metrics["total_samples"] 131 | avg_rewards = self.global_metrics["avg_reward"] 132 | plt.plot(total_samples, avg_rewards, marker='s', linewidth=2, linestyle='--', label=f'FedRLHF K={NUM_CLIENTS}') 133 | plt.xlabel("Total Samples") 134 | plt.ylabel("Average Reward") 135 | plt.title("Performance vs. Number of Samples") 136 | plt.legend() 137 | plt.grid(True) 138 | os.makedirs("training_logs", exist_ok=True) 139 | plt.savefig(f"training_logs/performance_vs_samples_k{NUM_CLIENTS}.png") 140 | plt.close() 141 | print(f"Performance vs. Samples plot saved as 'training_logs/performance_vs_samples_k{NUM_CLIENTS}.png'") 142 | 143 | def start_federated_server(): 144 | print("Starting federated server") 145 | strategy = CustomFedAvg( 146 | fraction_fit=1.0, 147 | min_fit_clients=NUM_CLIENTS, 148 | min_evaluate_clients=NUM_CLIENTS, 149 | min_available_clients=NUM_CLIENTS, 150 | on_fit_config_fn=lambda rnd: {"round": rnd} # Pass the current round number 151 | ) 152 | 153 | # Start the server 154 | fl.server.start_server( 155 | server_address="0.0.0.0:8080", 156 | config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), 157 | strategy=strategy, 158 | grpc_max_message_length=1024 * 1024 * 1024 # 1 GB 159 | ) 160 | 161 | def start_federated_simulation(): 162 | print("Starting federated simulation") 163 | 164 | # Import the client class 165 | from client import FedRLHFClient 166 | 167 | # Define the client function 168 | def client_fn(cid: str): 169 | # Convert client ID to integer 170 | client_id = int(cid) 171 | client_lambda_lm = LAMBDA_LMs[client_id] 172 | return FedRLHFClient(client_id=client_id, num_clients=NUM_CLIENTS, num_rounds=NUM_ROUNDS, lambda_lm=client_lambda_lm) 173 | 174 | # Create the strategy 175 | strategy = CustomFedAvg( 176 | fraction_fit=1.0, 177 | min_fit_clients=NUM_CLIENTS, 178 | min_evaluate_clients=NUM_CLIENTS, 179 | min_available_clients=NUM_CLIENTS, 180 | on_fit_config_fn=lambda rnd: {"round": rnd} # Pass the current round number 181 | ) 182 | 183 | # Start the simulation 184 | fl.simulation.start_simulation( 185 | client_fn=client_fn, 186 | num_clients=NUM_CLIENTS, 187 | config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS), 188 | strategy=strategy, 189 | client_resources={"num_cpus": 1}, 190 | ) 191 | 192 | if __name__ == "__main__": 193 | if FLOWER_SIM_MODE == "server": 194 | start_federated_server() 195 | elif FLOWER_SIM_MODE == "simulation": 196 | start_federated_simulation() 197 | else: 198 | raise ValueError(f"Invalid server mode: {FLOWER_SIM_MODE}. 
Please set FLOWER_SIM_MODE to 'server' or 'simulation'.") 199 | -------------------------------------------------------------------------------- /exp-LLM-IMDB/centralized_training.py: -------------------------------------------------------------------------------- 1 | # centralized_training.py 2 | 3 | import os 4 | import torch 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import json 8 | from datasets import load_dataset 9 | from transformers import AutoTokenizer, pipeline 10 | from trl import PPOTrainer, AutoModelForCausalLMWithValueHead, PPOConfig 11 | from tqdm import tqdm 12 | import logging 13 | 14 | from config import TRAINER_CONFIG, DATASET_DIVISION, SEED 15 | 16 | # Set up logging 17 | logger = logging.getLogger(__name__) 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') 19 | 20 | # Define PPO configuration 21 | ppo_config = PPOConfig( 22 | model_name="gpt2", 23 | batch_size=TRAINER_CONFIG['BATCH_SIZE'], 24 | mini_batch_size=TRAINER_CONFIG['MINI_BATCH_SIZE'], 25 | learning_rate=TRAINER_CONFIG['LEARNING_RATE'], 26 | gradient_accumulation_steps=1, 27 | seed=SEED, 28 | query_dataset="imdb", 29 | dataset_num_proc=4, 30 | ) 31 | 32 | class CentralizedRLHFTrainer: 33 | def __init__(self): 34 | self.model = None 35 | self.tokenizer = None 36 | self.train_dataset = None 37 | self.eval_dataset = None 38 | self.trainer = None 39 | self.sentiment_pipe = None 40 | self.rewards_over_samples = [] 41 | self.losses_over_samples = [] 42 | self.step = 0 43 | self.total_samples = 0 44 | 45 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | logger.info(f"Using device: {self.device}") 47 | 48 | self.verbose = False # Set to True for detailed logging 49 | 50 | self.initialize_components() 51 | 52 | def initialize_components(self): 53 | logger.info("Initializing components for centralized training") 54 | self.model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_config.model_name) 55 | self.tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name, padding_side='left', clean_up_tokenization_spaces=True) 56 | self.tokenizer.pad_token = self.tokenizer.eos_token 57 | self.model.config.pad_token_id = self.tokenizer.eos_token_id 58 | self.model.config.padding_side = 'left' 59 | 60 | # Build dataset 61 | self.build_dataset(ppo_config.query_dataset, ppo_config.dataset_num_proc) 62 | logger.info(f"Training dataset size: {len(self.train_dataset)}") 63 | logger.info(f"Evaluation dataset size: {len(self.eval_dataset)}") 64 | 65 | # Move model to device 66 | self.model = self.model.to(self.device) 67 | logger.info(f"Model moved to {self.device}") 68 | 69 | # Initialize sentiment analysis pipeline 70 | self.sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=0 if torch.cuda.is_available() else -1) 71 | logger.info("Sentiment analysis pipeline initialized") 72 | 73 | # Initialize PPO Trainer 74 | self.trainer = PPOTrainer( 75 | config=ppo_config, 76 | model=self.model, 77 | ref_model=None, 78 | tokenizer=self.tokenizer, 79 | dataset=self.train_dataset, 80 | data_collator=self.collator 81 | ) 82 | 83 | def build_dataset(self, query_dataset, dataset_num_proc, input_min_text_length=2, input_max_text_length=8): 84 | logger.info("Building dataset") 85 | full_ds = load_dataset(query_dataset, split="train") 86 | full_ds = full_ds.rename_columns({"text": "review"}) 87 | full_ds = full_ds.filter(lambda x: len(x["review"]) > 200, num_proc=dataset_num_proc) 88 | 89 | # Limit dataset size 
for testing 90 | max_total_samples = len(full_ds) // DATASET_DIVISION 91 | full_ds = full_ds.select(range(min(len(full_ds), max_total_samples))) 92 | 93 | def tokenize(sample): 94 | encoding = self.tokenizer( 95 | sample["review"], 96 | truncation=True, 97 | max_length=input_max_text_length, 98 | return_tensors="pt", 99 | ) 100 | sample["input_ids"] = encoding["input_ids"][0] 101 | sample["attention_mask"] = encoding["attention_mask"][0] 102 | sample["query"] = self.tokenizer.decode( 103 | sample["input_ids"], 104 | skip_special_tokens=True, 105 | clean_up_tokenization_spaces=True 106 | ) 107 | return sample 108 | 109 | ds = full_ds.map(tokenize, num_proc=dataset_num_proc) 110 | ds.set_format(type="torch", columns=["input_ids", "attention_mask", "query"]) 111 | logger.info(f"Dataset built successfully with {len(ds)} samples") 112 | 113 | # Split into train and eval datasets 114 | split_ds = ds.train_test_split(test_size=0.2, seed=ppo_config.seed) 115 | self.train_dataset = split_ds['train'] 116 | self.eval_dataset = split_ds['test'] 117 | 118 | def collator(self, data): 119 | return { 120 | "input_ids": torch.nn.utils.rnn.pad_sequence([d["input_ids"] for d in data], batch_first=True, padding_value=self.tokenizer.pad_token_id), 121 | "attention_mask": torch.nn.utils.rnn.pad_sequence([torch.ones_like(d["input_ids"]) for d in data], batch_first=True, padding_value=0), 122 | "query": [d["query"] if "query" in d else self.tokenizer.decode(d["input_ids"]) for d in data] 123 | } 124 | 125 | def compute_rewards(self, queries, responses): 126 | texts = [q + r for q, r in zip(queries, responses)] 127 | max_length = 512 128 | 129 | # Sentiment analysis reward 130 | pipe_outputs = self.sentiment_pipe(texts, truncation=True, max_length=max_length) 131 | sentiment_rewards = [] 132 | for i, output in enumerate(pipe_outputs): 133 | if isinstance(output, dict) and "label" in output and "score" in output: 134 | reward = torch.tensor(output["score"] if output["label"] == "POSITIVE" else 1 - output["score"], device=self.device) 135 | else: 136 | logger.info(f"Unexpected output format: {output}") 137 | reward = torch.tensor(0.5, device=self.device) # Default reward if no valid sentiment data 138 | sentiment_rewards.append(reward) 139 | 140 | # Verbose logging for first few samples 141 | if self.verbose and i < 5: 142 | logger.info(f"Sample {i}: Query: {queries[i][:50]}... Response: {responses[i][:50]}... 
Sentiment Reward: {reward.item():.4f}") 143 | 144 | # Convert list of sentiment rewards to tensor 145 | sentiment_rewards = torch.stack(sentiment_rewards) 146 | 147 | # Calculate intrinsic reward (Negative log probability) 148 | with torch.no_grad(): 149 | # Tokenize responses and move to the device 150 | lm_inputs = self.tokenizer(responses, return_tensors="pt", padding=True, truncation=True).to(self.device) 151 | 152 | # Pass the inputs through the model 153 | lm_outputs = self.model(**lm_inputs) 154 | lm_logits = lm_outputs[0] 155 | 156 | # Get the actual tokens in the responses 157 | input_ids = lm_inputs["input_ids"] 158 | 159 | # Shift input_ids to get the correct next-token targets 160 | shift_labels = input_ids[..., 1:].contiguous() 161 | 162 | # Get the corresponding logits for the shifted tokens 163 | shift_logits = lm_logits[..., :-1, :].contiguous() 164 | 165 | # Compute log probabilities for the actual tokens (shift_labels) 166 | log_probs = torch.nn.functional.cross_entropy( 167 | shift_logits.view(-1, shift_logits.size(-1)), 168 | shift_labels.view(-1), 169 | reduction='none' 170 | ) 171 | 172 | # Reshape log_probs back to [batch_size, seq_len-1] 173 | log_probs = log_probs.view(shift_labels.size()) 174 | 175 | # Compute the intrinsic reward as the negative log probability (mean over tokens in the sequence) 176 | intrinsic_rewards = -log_probs.mean(dim=1) 177 | 178 | # Define the theoretical minimum and maximum intrinsic rewards 179 | V = self.tokenizer.vocab_size # Vocabulary size, e.g., 50,000 180 | min_possible = -torch.log(torch.tensor(V, dtype=torch.float32, device=self.device)) 181 | max_possible = torch.tensor(0.0, device=self.device) 182 | 183 | # Normalize intrinsic rewards to [0, 1] 184 | intrinsic_rewards_norm = (intrinsic_rewards - min_possible) / (max_possible - min_possible) 185 | intrinsic_rewards_norm = torch.clamp(intrinsic_rewards_norm, min=0.0, max=1.0) 186 | 187 | # Compute combined rewards 188 | lambda_lm = 0.5 189 | combined_rewards = lambda_lm * sentiment_rewards + (1 - lambda_lm) * intrinsic_rewards_norm 190 | 191 | # Ensure rewards are lists of tensors 192 | combined_rewards = [r for r in combined_rewards] 193 | 194 | # Log combined reward for debugging 195 | if self.verbose: 196 | logger.info(f"Combined Reward for Batch: {torch.stack(combined_rewards).mean().item():.4f}") 197 | 198 | return combined_rewards 199 | 200 | def train(self): 201 | generation_kwargs = { 202 | "min_length": -1, 203 | "top_k": 0.0, 204 | "top_p": 1.0, 205 | "do_sample": True, 206 | "pad_token_id": self.tokenizer.eos_token_id, 207 | "max_new_tokens": 32, 208 | } 209 | 210 | total_rewards = [] 211 | total_losses = [] 212 | 213 | for epoch in range(TRAINER_CONFIG["num_epochs"]): 214 | logger.info(f"Starting epoch {epoch + 1}/{TRAINER_CONFIG['num_epochs']}") 215 | for batch_idx, batch in enumerate(tqdm(self.trainer.dataloader, desc=f"Epoch {epoch + 1} Training")): 216 | query_tensors = batch["input_ids"].to(self.device) 217 | query_tensor_list = [tensor for tensor in query_tensors] 218 | 219 | response_tensors = self.trainer.generate( 220 | query_tensor_list, 221 | return_prompt=False, 222 | **generation_kwargs 223 | ) 224 | 225 | decoded_queries = self.tokenizer.batch_decode(query_tensors, clean_up_tokenization_spaces=True) 226 | decoded_responses = self.tokenizer.batch_decode(response_tensors, clean_up_tokenization_spaces=True) 227 | 228 | combined_rewards = self.compute_rewards(decoded_queries, decoded_responses) 229 | rewards = combined_rewards 230 | avg_reward = 
sum([r.item() for r in rewards]) / len(rewards) 231 | total_rewards.extend([r.item() for r in rewards]) 232 | 233 | stats = self.trainer.step(query_tensor_list, response_tensors, rewards) 234 | 235 | loss = stats.get("ppo/loss/total", 0.0) 236 | total_losses.append(loss) 237 | 238 | # Increment the total sample count 239 | batch_size = len(query_tensors) 240 | self.total_samples += batch_size 241 | 242 | # Collect metrics 243 | self.rewards_over_samples.append((self.total_samples, avg_reward)) 244 | self.losses_over_samples.append((self.total_samples, loss)) 245 | 246 | # Increment the total step count 247 | self.step += 1 248 | 249 | if self.verbose: 250 | logger.info(f"Batch {batch_idx + 1}:") 251 | logger.info(f" Average reward for the batch: {avg_reward:.4f}") 252 | logger.info(f" Loss: {loss:.4f}") 253 | 254 | # Plot metrics 255 | self.plot_metrics() 256 | 257 | # Save metrics to a JSON file for combined plotting 258 | metrics_data = { 259 | "rounds": list(range(1, len(self.rewards_over_samples) + 1)), 260 | "avg_rewards": [r for _, r in self.rewards_over_samples], 261 | "avg_losses": [l for _, l in self.losses_over_samples], 262 | "total_steps": [self.step] * len(self.rewards_over_samples), 263 | "total_samples": [s for s, _ in self.rewards_over_samples], 264 | } 265 | os.makedirs("metrics", exist_ok=True) 266 | with open("metrics/metrics_centralized.json", "w") as f: 267 | json.dump(metrics_data, f) 268 | logger.info("Centralized training metrics saved.") 269 | 270 | def plot_metrics(self): 271 | if not self.rewards_over_samples or not self.losses_over_samples: 272 | logger.info("No metrics to plot for centralized training.") 273 | return 274 | 275 | samples_rewards, avg_rewards = zip(*self.rewards_over_samples) 276 | samples_losses, losses = zip(*self.losses_over_samples) 277 | 278 | plt.figure(figsize=(12, 6)) 279 | 280 | # Rewards subplot 281 | plt.subplot(2, 1, 1) 282 | plt.plot(samples_rewards, avg_rewards, marker='o', label='Average Reward') 283 | plt.title("Centralized Training - Average Reward over Samples") 284 | plt.xlabel("Total Samples") 285 | plt.ylabel("Average Reward") 286 | plt.legend() 287 | 288 | # Losses subplot 289 | plt.subplot(2, 1, 2) 290 | plt.plot(samples_losses, losses, marker='o', label='Loss') 291 | plt.title("Centralized Training - Loss over Samples") 292 | plt.xlabel("Total Samples") 293 | plt.ylabel("Loss") 294 | plt.legend() 295 | 296 | plt.tight_layout() 297 | os.makedirs("training_logs", exist_ok=True) 298 | plt.savefig("training_logs/ppo_training_centralized.png") 299 | plt.close() 300 | 301 | logger.info("Centralized training metrics plotted and saved.") 302 | 303 | if __name__ == "__main__": 304 | trainer = CentralizedRLHFTrainer() 305 | trainer.train() -------------------------------------------------------------------------------- /exp-LLM-IMDB/client.py: -------------------------------------------------------------------------------- 1 | # client.py 2 | 3 | import flwr as fl 4 | from flwr.common import ( 5 | FitIns, 6 | FitRes, 7 | Parameters, 8 | GetParametersIns, 9 | GetParametersRes, 10 | EvaluateIns, 11 | EvaluateRes, 12 | NDArrays, 13 | Status, 14 | Code, 15 | ) 16 | from typing import Dict, List 17 | from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays 18 | 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import torch 22 | from tqdm import tqdm 23 | from datasets import load_dataset 24 | from transformers import AutoTokenizer, pipeline 25 | from trl import PPOTrainer, 
AutoModelForCausalLMWithValueHead, PPOConfig 26 | import os 27 | from collections import OrderedDict 28 | import logging 29 | import csv 30 | import json 31 | 32 | from config import TRAINER_CONFIG, DATASET_DIVISION, SEED, VERBOSE 33 | 34 | # Set up logging 35 | logger = logging.getLogger(__name__) 36 | logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') 37 | 38 | # Define PPO configuration 39 | ppo_config = PPOConfig( 40 | model_name="gpt2", 41 | batch_size=TRAINER_CONFIG['BATCH_SIZE'], # Adjust based on GPU memory 42 | mini_batch_size=TRAINER_CONFIG['MINI_BATCH_SIZE'], # Adjust based on GPU memory 43 | learning_rate=TRAINER_CONFIG['LEARNING_RATE'], 44 | gradient_accumulation_steps=1, 45 | seed=SEED, 46 | query_dataset="imdb", 47 | dataset_num_proc=4, 48 | ) 49 | 50 | class FedRLHFClient(fl.client.Client): 51 | def __init__(self, client_id: int, num_clients: int, num_rounds: int, lambda_lm: float): 52 | super().__init__() 53 | logger.info(f"Initializing FedRLHFClient {client_id} with lambda_lm: {lambda_lm}") 54 | self.client_id = client_id 55 | self.num_clients = num_clients 56 | self.num_rounds = num_rounds # Total number of federation rounds 57 | # Initialize metrics lists 58 | self.rewards_over_samples = [] 59 | self.losses_over_samples = [] 60 | # Initialize step counter 61 | self.step = 0 62 | self.total_samples = 0 # Cumulative samples processed 63 | 64 | self.lambda_lm = lambda_lm 65 | 66 | # Initialize components 67 | self.model = None 68 | self.tokenizer = None 69 | self.train_dataset = None 70 | self.eval_dataset = None 71 | self.round_datasets = [] # List to hold data partitions for each round 72 | self.trainer = None 73 | self.sentiment_pipe = None 74 | self.stats_log = {"rewards": [], "losses": []} 75 | self.client_metrics = { 76 | "avg_rewards": [], 77 | "avg_losses": [], 78 | "num_examples": 0, 79 | "total_samples": [], 80 | } 81 | 82 | # Clear previous evaluation logs 83 | self.clear_evaluation_logs() 84 | 85 | self.evaluation_samples = None # List to hold evaluation samples for actual response 86 | 87 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 88 | logger.info(f"Using device: {self.device}") 89 | 90 | self.verbose = VERBOSE # Set to True for detailed logging 91 | 92 | self.initialize_components() 93 | 94 | def initialize_components(self): 95 | logger.info("Initializing components") 96 | try: 97 | # Initialize model and tokenizer 98 | self.model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_config.model_name) 99 | self.tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name, padding_side='left', clean_up_tokenization_spaces=True) 100 | self.tokenizer.pad_token = self.tokenizer.eos_token 101 | self.model.config.pad_token_id = self.tokenizer.eos_token_id 102 | self.model.config.padding_side = 'left' 103 | 104 | # Build dataset 105 | self.build_dataset(ppo_config.query_dataset, ppo_config.dataset_num_proc) 106 | logger.info(f"Training dataset size: {len(self.train_dataset)}") 107 | logger.info(f"Evaluation dataset size: {len(self.eval_dataset)}") 108 | 109 | # Partition the training data into num_rounds parts 110 | self.partition_data() 111 | 112 | # Move model to device 113 | self.model = self.model.to(self.device) 114 | logger.info(f"Model moved to {self.device}") 115 | 116 | # Initialize sentiment analysis pipeline 117 | self.sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=0 if torch.cuda.is_available() else -1) 118 | logger.info("Sentiment 
analysis pipeline initialized") 119 | 120 | # Sample evaluation data once for consistent evaluation 121 | self.sample_evaluation_data(num_samples=30) 122 | except Exception as e: 123 | logger.info(f"Error in initialize_components: {str(e)}") 124 | raise 125 | 126 | def build_dataset(self, query_dataset, dataset_num_proc, input_min_text_length=2, input_max_text_length=8): 127 | logger.info("Building dataset") 128 | full_ds = load_dataset(query_dataset, split="train") 129 | full_ds = full_ds.rename_columns({"text": "review"}) 130 | full_ds = full_ds.filter(lambda x: len(x["review"]) > 200, num_proc=dataset_num_proc) 131 | 132 | # Limit dataset size for testing 133 | max_total_samples = len(full_ds) // DATASET_DIVISION 134 | full_ds = full_ds.select(range(min(len(full_ds), max_total_samples))) 135 | 136 | # Partition dataset among clients 137 | total_size = len(full_ds) 138 | partition_size = total_size // self.num_clients 139 | 140 | start_idx = self.client_id * partition_size 141 | end_idx = start_idx + partition_size if self.client_id < self.num_clients - 1 else total_size 142 | 143 | ds = full_ds.select(range(start_idx, end_idx)) 144 | 145 | def tokenize(sample): 146 | encoding = self.tokenizer( 147 | sample["review"], 148 | truncation=True, 149 | max_length=input_max_text_length, 150 | return_tensors="pt", 151 | ) 152 | sample["input_ids"] = encoding["input_ids"][0] 153 | sample["attention_mask"] = encoding["attention_mask"][0] 154 | sample["query"] = self.tokenizer.decode( 155 | sample["input_ids"], 156 | skip_special_tokens=True, 157 | clean_up_tokenization_spaces=True 158 | ) 159 | return sample 160 | 161 | ds = ds.map(tokenize, num_proc=dataset_num_proc) 162 | ds.set_format(type="torch", columns=["input_ids", "attention_mask", "query"]) 163 | logger.info(f"Dataset built successfully for client {self.client_id} with {len(ds)} samples") 164 | 165 | # Split into train and eval datasets 166 | split_ds = ds.train_test_split(test_size=0.2, seed=ppo_config.seed) 167 | self.train_dataset = split_ds['train'] 168 | self.eval_dataset = split_ds['test'] 169 | 170 | def partition_data(self): 171 | # Partition the training data into num_rounds parts 172 | num_samples = len(self.train_dataset) 173 | samples_per_round = num_samples // self.num_rounds 174 | self.round_datasets = [] 175 | 176 | for i in range(self.num_rounds): 177 | start_idx = i * samples_per_round 178 | end_idx = start_idx + samples_per_round if i < self.num_rounds - 1 else num_samples 179 | round_dataset = self.train_dataset.select(range(start_idx, end_idx)) 180 | self.round_datasets.append(round_dataset) 181 | logger.info(f"Client {self.client_id} - Round {i+1}: {len(round_dataset)} samples") 182 | 183 | def collator(self, data): 184 | return { 185 | "input_ids": torch.nn.utils.rnn.pad_sequence([d["input_ids"] for d in data], batch_first=True, padding_value=self.tokenizer.pad_token_id), 186 | "attention_mask": torch.nn.utils.rnn.pad_sequence([torch.ones_like(d["input_ids"]) for d in data], batch_first=True, padding_value=0), 187 | "query": [d["query"] if "query" in d else self.tokenizer.decode(d["input_ids"]) for d in data] 188 | } 189 | 190 | def compute_rewards(self, queries, responses): 191 | texts = [q + r for q, r in zip(queries, responses)] 192 | max_length = 512 193 | 194 | # Sentiment analysis reward 195 | pipe_outputs = self.sentiment_pipe(texts, truncation=True, max_length=max_length) 196 | sentiment_rewards = [] 197 | for i, output in enumerate(pipe_outputs): 198 | if isinstance(output, dict) and "label" in output and 
"score" in output: 199 | reward = torch.tensor(output["score"] if output["label"] == "POSITIVE" else 1 - output["score"], device=self.device) 200 | else: 201 | logger.info(f"Unexpected output format: {output}") 202 | reward = torch.tensor(0.5, device=self.device) # Default reward if no valid sentiment data 203 | sentiment_rewards.append(reward) 204 | 205 | # Verbose logging for first few samples 206 | if self.verbose and i < 5: 207 | logger.info(f"Sample {i}: Query: {queries[i][:50]}... Response: {responses[i][:50]}... Sentiment Reward: {reward.item():.4f}") 208 | 209 | # Convert list of sentiment rewards to tensor 210 | sentiment_rewards = torch.stack(sentiment_rewards) 211 | 212 | # Calculate intrinsic reward (Negative log probability) 213 | with torch.no_grad(): 214 | # Tokenize responses and move to the device 215 | lm_inputs = self.tokenizer(responses, return_tensors="pt", padding=True, truncation=True).to(self.device) 216 | 217 | # Pass the inputs through the model (assuming model outputs a tuple) 218 | lm_outputs = self.model(**lm_inputs) 219 | lm_logits = lm_outputs[0] 220 | 221 | # Get the actual tokens in the responses (i.e., target tokens for calculating probabilities) 222 | input_ids = lm_inputs["input_ids"] 223 | 224 | # Shift input_ids to get the correct next-token targets 225 | shift_labels = input_ids[..., 1:].contiguous() 226 | 227 | # Get the corresponding logits for the shifted tokens 228 | shift_logits = lm_logits[..., :-1, :].contiguous() 229 | 230 | # Compute log probabilities for the actual tokens (shift_labels) 231 | log_probs = torch.nn.functional.cross_entropy( 232 | shift_logits.view(-1, shift_logits.size(-1)), 233 | shift_labels.view(-1), 234 | reduction='none' 235 | ) 236 | 237 | # Reshape log_probs back to [batch_size, seq_len-1] 238 | log_probs = log_probs.view(shift_labels.size()) 239 | 240 | # Compute the intrinsic reward as the negative log probability (mean over tokens in the sequence) 241 | intrinsic_rewards = -log_probs.mean(dim=1) 242 | 243 | # Define the theoretical minimum and maximum intrinsic rewards 244 | V = self.tokenizer.vocab_size # Vocabulary size, e.g., 50,000 245 | min_possible = -torch.log(torch.tensor(V, dtype=torch.float32, device=self.device)) 246 | max_possible = torch.tensor(0.0, device=self.device) 247 | 248 | # Normalize intrinsic rewards to [0, 1] 249 | intrinsic_rewards_norm = (intrinsic_rewards - min_possible) / (max_possible - min_possible) 250 | intrinsic_rewards_norm = torch.clamp(intrinsic_rewards_norm, min=0.0, max=1.0) 251 | 252 | # Compute combined rewards 253 | combined_rewards = self.lambda_lm * sentiment_rewards + (1 - self.lambda_lm) * intrinsic_rewards_norm 254 | 255 | # Ensure rewards are lists of tensors 256 | combined_rewards = [r for r in combined_rewards] 257 | sentiment_rewards = [r for r in sentiment_rewards] 258 | intrinsic_rewards = [r for r in intrinsic_rewards_norm] 259 | 260 | # Log combined reward for debugging 261 | if self.verbose: 262 | logger.info(f"Combined Reward for Batch: {torch.stack(combined_rewards).mean().item():.4f}") 263 | 264 | return combined_rewards, sentiment_rewards, intrinsic_rewards_norm 265 | 266 | def plot_metrics(self): 267 | if not self.rewards_over_samples or not self.losses_over_samples: 268 | logger.info(f"No metrics to plot for client {self.client_id}.") 269 | return 270 | 271 | samples_rewards, avg_rewards = zip(*self.rewards_over_samples) 272 | samples_losses, losses = zip(*self.losses_over_samples) 273 | 274 | plt.figure(figsize=(12, 6)) 275 | 276 | # Rewards subplot 277 | 
plt.subplot(2, 1, 1) 278 | plt.plot(samples_rewards, avg_rewards, marker='o', label='Average Reward') 279 | plt.title(f"Client {self.client_id} - Average Reward over Samples") 280 | plt.xlabel("Total Samples") 281 | plt.ylabel("Average Reward") 282 | plt.legend() 283 | 284 | # Losses subplot 285 | plt.subplot(2, 1, 2) 286 | plt.plot(samples_losses, losses, marker='o', label='Loss') 287 | plt.title(f"Client {self.client_id} - Loss over Samples") 288 | plt.xlabel("Total Samples") 289 | plt.ylabel("Loss") 290 | plt.legend() 291 | 292 | plt.tight_layout() 293 | os.makedirs("training_logs", exist_ok=True) 294 | plt.savefig(f"training_logs/ppo_training_client_{self.client_id}.png") 295 | plt.close() 296 | 297 | logger.info(f"Training metrics plotted and saved for client {self.client_id}.") 298 | 299 | # Save metrics to a JSON file for individual client analysis 300 | metrics_data = { 301 | "total_samples": [s for s, _ in self.rewards_over_samples], 302 | "avg_rewards": [r for _, r in self.rewards_over_samples], 303 | "losses": [l for _, l in self.losses_over_samples], 304 | "client_id": self.client_id, 305 | } 306 | os.makedirs("metrics", exist_ok=True) 307 | with open(f"metrics/metrics_client_{self.client_id}.json", "w") as f: 308 | json.dump(metrics_data, f) 309 | logger.info(f"Metrics saved for client {self.client_id}.") 310 | 311 | def fit(self, ins: FitIns) -> FitRes: 312 | self.set_parameters(ins.parameters) 313 | 314 | # Retrieve the current round number from FitIns.config 315 | round_num = int(ins.config.get("round", 1)) 316 | logger.info(f"Client {self.client_id} - Starting training for round {round_num}") 317 | 318 | # Select the data partition for the current round 319 | if round_num <= len(self.round_datasets): 320 | current_dataset = self.round_datasets[round_num - 1] 321 | else: 322 | current_dataset = self.round_datasets[-1] 323 | 324 | # Initialize PPO Trainer with the current round's dataset 325 | self.trainer = PPOTrainer( 326 | config=ppo_config, 327 | model=self.model, 328 | ref_model=None, 329 | tokenizer=self.tokenizer, 330 | dataset=current_dataset, 331 | data_collator=self.collator 332 | ) 333 | 334 | generation_kwargs = { 335 | "min_length": -1, 336 | "top_k": 0.0, 337 | "top_p": 1.0, 338 | "do_sample": True, 339 | "pad_token_id": self.tokenizer.eos_token_id, 340 | "max_new_tokens": 32, 341 | } 342 | 343 | total_rewards = [] 344 | total_losses = [] 345 | num_examples = 0 # Number of samples processed in this round 346 | 347 | for batch_idx, batch in enumerate(tqdm(self.trainer.dataloader, desc=f"Client {self.client_id} - Round {round_num} Training")): 348 | query_tensors = batch["input_ids"].to(self.device) 349 | query_tensor_list = [tensor for tensor in query_tensors] 350 | 351 | response_tensors = self.trainer.generate( 352 | query_tensor_list, 353 | return_prompt=False, 354 | **generation_kwargs 355 | ) 356 | 357 | decoded_queries = self.tokenizer.batch_decode(query_tensors, clean_up_tokenization_spaces=True) 358 | decoded_responses = self.tokenizer.batch_decode(response_tensors, clean_up_tokenization_spaces=True) 359 | 360 | combined_rewards, _, _ = self.compute_rewards(decoded_queries, decoded_responses) 361 | rewards = combined_rewards 362 | avg_reward = sum([r.item() for r in rewards]) / len(rewards) 363 | total_rewards.extend([r.item() for r in rewards]) 364 | 365 | stats = self.trainer.step(query_tensor_list, response_tensors, rewards) 366 | 367 | loss = stats.get("ppo/loss/total", 0.0) 368 | total_losses.append(loss) 369 | 370 | # Increment the total sample 
count 371 | batch_size = len(query_tensors) 372 | self.total_samples += batch_size 373 | num_examples += batch_size 374 | 375 | # Collect metrics 376 | self.rewards_over_samples.append((self.total_samples, avg_reward)) 377 | self.losses_over_samples.append((self.total_samples, loss)) 378 | 379 | if self.verbose: 380 | logger.info(f"Batch {batch_idx + 1}:") 381 | logger.info(f" Average reward for the batch: {avg_reward:.4f}") 382 | logger.info(f" Loss: {loss:.4f}") 383 | 384 | # Increment the total step count 385 | self.step += 1 386 | 387 | logger.info(f"Training completed for client {self.client_id} - Round {round_num}") 388 | 389 | # Save metrics to file after each round 390 | self.plot_metrics() 391 | 392 | # Save the model after training 393 | self.save_model(round_num) 394 | 395 | # Sample and evaluate on a few samples from eval dataset 396 | self.sample_and_evaluate(round_num) 397 | 398 | # Update total number of examples processed 399 | self.client_metrics["num_examples"] += num_examples 400 | self.client_metrics["total_samples"].append(self.total_samples) 401 | 402 | # Calculate average reward and loss for this round 403 | avg_reward = sum(total_rewards) / len(total_rewards) if total_rewards else 0.0 404 | avg_loss = sum(total_losses) / len(total_losses) if total_losses else 0.0 405 | 406 | self.client_metrics["avg_rewards"].append(avg_reward) 407 | self.client_metrics["avg_losses"].append(avg_loss) 408 | 409 | # Prepare metrics to send to the server 410 | metrics = { 411 | "client_id": self.client_id, 412 | "avg_reward": avg_reward, 413 | "avg_loss": avg_loss, 414 | "num_examples": num_examples, 415 | "total_steps": self.step, 416 | "total_samples": self.total_samples 417 | } 418 | 419 | parameters = self.get_parameters(GetParametersIns(config={})).parameters 420 | return FitRes( 421 | parameters=parameters, 422 | num_examples=num_examples, 423 | metrics=metrics, # Ensure metrics are passed here 424 | status=Status(code=Code.OK, message="Success") 425 | ) 426 | 427 | def clear_evaluation_logs(self): 428 | # Delete the client's CSV file in evaluation_logs directory 429 | filename = f"evaluation_logs/client_{self.client_id}.csv" 430 | if os.path.exists(filename): 431 | os.remove(filename) 432 | logger.info(f"Previous evaluation log {filename} removed.") 433 | else: 434 | logger.info(f"No previous evaluation log to remove for client {self.client_id}.") 435 | 436 | def sample_evaluation_data(self, num_samples=5): 437 | # Sample a few samples from the eval dataset 438 | self.evaluation_samples = self.eval_dataset.shuffle(seed=SEED).select(range(num_samples)) 439 | logger.info(f"Sampled {num_samples} evaluation samples for client {self.client_id}") 440 | 441 | def sample_and_evaluate(self, round_num): 442 | # Use the pre-sampled evaluation data 443 | sampled_dataset = self.evaluation_samples 444 | 445 | # Initialize lists to store results 446 | results = [] 447 | 448 | # Iterate over samples 449 | for sample in sampled_dataset: 450 | query_tensor = sample["input_ids"].unsqueeze(0).to(self.device) 451 | query = sample["query"] 452 | 453 | # Generate response 454 | generation_kwargs = { 455 | "min_length": -1, 456 | "top_k": 0.0, 457 | "top_p": 1.0, 458 | "do_sample": True, 459 | "pad_token_id": self.tokenizer.eos_token_id, 460 | "max_new_tokens": 32, 461 | } 462 | response_tensor = self.model.generate( 463 | query_tensor, 464 | **generation_kwargs 465 | ) 466 | generated_tokens = response_tensor[0] 467 | generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True, 
            if generated_text.startswith(query):
                response = generated_text[len(query):]
            else:
                response = generated_text  # Fallback in case the query is not at the start

            # Compute rewards
            combined_reward, sentiment_reward, intrinsic_reward = self.compute_rewards([query], [response])
            combined_reward = combined_reward[0].item()
            sentiment_reward = sentiment_reward[0].item()
            intrinsic_reward = intrinsic_reward[0].item()

            # Store the result
            results.append({
                "client_id": self.client_id,
                "round": round_num,
                "query": query,
                "response": response,
                "sentiment_reward": sentiment_reward,
                "intrinsic_reward": intrinsic_reward,
                "combined_reward": combined_reward
            })

        # Save results to a single CSV file per client
        os.makedirs("evaluation_logs", exist_ok=True)
        filename = f"evaluation_logs/client_{self.client_id}.csv"
        fieldnames = ["client_id", "round", "query", "response", "sentiment_reward", "intrinsic_reward", "combined_reward"]

        # Always open in append mode since we cleared the file at the start
        with open(filename, mode='a', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            # Write header only if the file does not exist or is empty
            if os.stat(filename).st_size == 0:
                writer.writeheader()
            for row in results:
                writer.writerow(row)

        logger.info(f"Evaluation samples appended for client {self.client_id} to {filename}")

    def save_model(self, round_num):
        # Create the directory if it doesn't exist
        save_dir = "trained_models"
        os.makedirs(save_dir, exist_ok=True)

        # Define the model save path
        model_save_path = os.path.join(save_dir, f"client_{self.client_id}_round_{round_num}")

        # Save the model and tokenizer
        self.model.save_pretrained(model_save_path)
        self.tokenizer.save_pretrained(model_save_path)
        logger.info(f"Model saved for client {self.client_id} at {model_save_path}")

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        ndarrays = []
        for name, param in self.model.named_parameters():
            if name in self.model.state_dict():
                ndarrays.append(param.detach().cpu().numpy())

        parameters = ndarrays_to_parameters(ndarrays)
        return GetParametersRes(parameters=parameters, status=Status(code=Code.OK, message="Success"))

    def get_parameters_as_ndarrays(self) -> List[np.ndarray]:
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters: Parameters) -> None:
        params_dict = zip(self.model.state_dict().keys(), parameters.tensors)
        state_dict = OrderedDict()
        for k, v in params_dict:
            if k in self.model.state_dict():
                try:
                    state_dict[k] = torch.tensor(np.frombuffer(v, dtype=np.float32).reshape(self.model.state_dict()[k].shape))
                except ValueError:
                    logger.info(f"Ignoring parameter {k} due to shape mismatch")
        self.model.load_state_dict(state_dict, strict=False)

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        self.set_parameters(ins.parameters)
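
        # Generation settings below mirror those used in fit(). Note that response
        # generation reuses self.trainer.generate, which assumes the PPO trainer has
        # already been initialized (e.g. by an earlier fit() call on this client).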
        generation_kwargs = {
            "min_length": -1,
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.eos_token_id,
            "max_new_tokens": 32,
        }

        total_rewards = []
        total_losses = []
        num_samples = 0

        eval_dataloader = torch.utils.data.DataLoader(
            self.eval_dataset,
            batch_size=ppo_config.batch_size,
            collate_fn=self.collator
        )

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc=f"Client {self.client_id} - Evaluation"):
            query_tensors = batch["input_ids"].to(self.device)
            query_tensor_list = [tensor for tensor in query_tensors]

            # Generate responses
            response_tensors = self.trainer.generate(
                query_tensor_list,
                return_prompt=False,
                **generation_kwargs
            )

            # Decode queries and responses
            decoded_queries = self.tokenizer.batch_decode(query_tensors, clean_up_tokenization_spaces=True)
            decoded_responses = self.tokenizer.batch_decode(response_tensors, clean_up_tokenization_spaces=True)

            # Compute rewards
            combined_rewards, _, _ = self.compute_rewards(decoded_queries, decoded_responses)
            rewards = combined_rewards
            rewards = torch.stack(rewards).to(self.device)
            avg_reward = rewards.mean().item()
            total_rewards.extend([r.item() for r in rewards])

            # Approximate loss as negative average reward
            loss = -avg_reward
            total_losses.append(loss)
            num_samples += len(rewards)

            if self.verbose:
                logger.info(f"Batch Evaluation: Average Reward: {avg_reward:.4f}, Loss: {loss:.4f}")

        # Compute average reward and loss
        average_reward = sum(total_rewards) / num_samples if num_samples > 0 else 0.0
        avg_loss = sum(total_losses) / len(total_losses) if total_losses else 0.0

        logger.info(f"Evaluation completed for client {self.client_id}. Average Reward: {average_reward:.4f}, Average Loss: {avg_loss:.4f}")

        return EvaluateRes(
            loss=avg_loss,
            num_examples=num_samples,
            metrics={
                "client_id": self.client_id,
                "average_reward": average_reward,
                "avg_loss": avg_loss,
                "num_examples": num_samples
            },
            status=Status(code=Code.OK, message="Success")
        )

    def set_parameters_from_ndarrays(self, params: NDArrays) -> None:
        params_dict = zip(self.model.state_dict().keys(), params)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=False)
--------------------------------------------------------------------------------