├── requirements.txt ├── fed_flwr ├── config │ ├── s_config.yml │ └── c_config.yml ├── run_FEDORA.sh ├── utils │ ├── nets.py │ ├── flwr_utils.py │ └── rl_utils.py ├── rl_server.py └── rl_client.py ├── fed_sim ├── utils │ └── rl_utils.py ├── rl_main.py └── FLAlgorithms │ ├── users │ └── userrl.py │ └── servers │ └── rlserver.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | d4rl==1.1 2 | flwr==1.2.0 3 | gym==0.24.1 4 | matplotlib==3.5.2 5 | numpy==1.21.6 6 | PyYAML==6.0 7 | torch==1.13.1 8 | -------------------------------------------------------------------------------- /fed_flwr/config/s_config.yml: -------------------------------------------------------------------------------- 1 | env_1: "hopper-expert-v2" 2 | env_2: "hopper-medium-v2" 3 | n_rounds: 1000 4 | n_clients: 10 5 | ncpr: 10 6 | temp_a: 0.1 7 | temp_c: 0.1 8 | dataset_size: 5000 -------------------------------------------------------------------------------- /fed_flwr/config/c_config.yml: -------------------------------------------------------------------------------- 1 | server_ip: "127.0.0.1:8080" # if running locally, also try "[::]:8080" 2 | seed: 1 3 | discount: 0.99 4 | tau: 0.005 5 | policy_noise_f: 0.2 6 | noise_clip_f: 0.5 7 | policy_freq: 2 8 | alpha: 2.5 9 | alpha_0: 1.0 10 | alpha_1: 0.0 11 | alpha_2: 1.0 12 | batch_size: 256 13 | decay_rate: 0.995 14 | l_r: 0.0003 15 | local_epochs: 20 -------------------------------------------------------------------------------- /fed_flwr/run_FEDORA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # initialize conda env 5 | source activate fed 6 | 7 | # start server, wait before launching clients 8 | python rl_server.py & 9 | sleep 3 10 | 11 | # start clients 12 | for i in `seq 1 5`; do 13 | echo "Starting client $i" 14 | python rl_client.py --gpu-index 0 --eval-env hopper-expert-v2 \ 15 | --start-index $(((i-1)*5000)) --stop-index $((i*5000)) & 16 | done 17 | 18 | for i in `seq 6 10`; do 19 | echo "Starting client $i" 20 | python rl_client.py --gpu-index 1 --eval-env hopper-medium-v2 \ 21 | --start-index $(((i-1)*5000)) --stop-index $((i*5000)) & 22 | done 23 | 24 | # enable CTRL+C to stop all background processes 25 | trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM 26 | 27 | # wait for all background processes to complete 28 | wait -------------------------------------------------------------------------------- /fed_flwr/utils/nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Actor(nn.Module): 7 | def __init__(self, state_dim, action_dim, max_action): 8 | super(Actor, self).__init__() 9 | 10 | self.l1 = nn.Linear(state_dim, 256) 11 | self.l2 = nn.Linear(256, 256) 12 | self.l3 = nn.Linear(256, action_dim) 13 | 14 | self.max_action = max_action 15 | 16 | 17 | def forward(self, state): 18 | a = F.relu(self.l1(state)) 19 | a = F.relu(self.l2(a)) 20 | return self.max_action * torch.tanh(self.l3(a)) 21 | 22 | 23 | class Critic(nn.Module): 24 | def __init__(self, state_dim, action_dim): 25 | super(Critic, self).__init__() 26 | 27 | # Q1 architecture 28 | self.l1 = nn.Linear(state_dim + action_dim, 256) 29 | self.l2 = nn.Linear(256, 256) 30 | self.l3 = nn.Linear(256, 1) 31 | 32 | # Q2 architecture 33 | self.l4 = nn.Linear(state_dim + action_dim, 256) 34 | self.l5 = nn.Linear(256, 256) 35 | self.l6 = 
nn.Linear(256, 1) 36 | 37 | 38 | def forward(self, state, action): 39 | sa = torch.cat([state, action], 1) 40 | 41 | q1 = F.relu(self.l1(sa)) 42 | q1 = F.relu(self.l2(q1)) 43 | q1 = self.l3(q1) 44 | 45 | q2 = F.relu(self.l4(sa)) 46 | q2 = F.relu(self.l5(q2)) 47 | q2 = self.l6(q2) 48 | return q1, q2 49 | 50 | 51 | def Q1(self, state, action): 52 | sa = torch.cat([state, action], 1) 53 | 54 | q1 = F.relu(self.l1(sa)) 55 | q1 = F.relu(self.l2(q1)) 56 | q1 = self.l3(q1) 57 | return q1 -------------------------------------------------------------------------------- /fed_flwr/utils/flwr_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Optional, Tuple, Union, OrderedDict 2 | from flwr.server.client_manager import ClientManager 3 | from flwr.server.client_proxy import ClientProxy 4 | from flwr.server.strategy.aggregate import aggregate 5 | from flwr.common.logger import log 6 | from flwr.common import ( 7 | EvaluateIns, 8 | EvaluateRes, 9 | FitIns, 10 | FitRes, 11 | MetricsAggregationFn, 12 | Parameters, 13 | Scalar, 14 | parameters_to_ndarrays, 15 | ndarrays_to_parameters, 16 | NDArrays 17 | ) 18 | from functools import reduce 19 | from logging import WARNING 20 | import numpy as np 21 | 22 | def aggregate_rl(results: List[Tuple[NDArrays, int]], pol_val: List[float], \ 23 | len_param_actor: int, temp_a: float=1.0, temp_c: float=1.0) -> NDArrays: 24 | """Compute exponentiated weighted average.""" 25 | 26 | results_val = list(zip([weights for weights, _ in results], pol_val)) 27 | 28 | results_a = [(weights[:len_param_actor], pol_val) \ 29 | for weights, pol_val in results_val] 30 | results_c = [(weights[len_param_actor:], pol_val) \ 31 | for weights, pol_val in results_val] 32 | 33 | # Calculate the total exponentiated value from training 34 | exp_val_a_total = sum([np.exp(temp_a * pol_val) for _, pol_val in results_a]) 35 | exp_val_c_total = sum([np.exp(temp_c * pol_val) for _, pol_val in results_c]) 36 | 37 | # Create a list of weights, each multiplied by the related policy values 38 | weighted_weights_a = [ 39 | [layer * np.exp(temp_a * pol_val) for layer in weights] \ 40 | for weights, pol_val in results_a 41 | ] 42 | weighted_weights_c = [ 43 | [layer * np.exp(temp_c * pol_val) for layer in weights] \ 44 | for weights, pol_val in results_c 45 | ] 46 | 47 | # Compute average weights of each layer 48 | weights_prime_a: NDArrays = [ 49 | reduce(np.add, layer_updates) / exp_val_a_total 50 | for layer_updates in zip(*weighted_weights_a) 51 | ] 52 | weights_prime_c: NDArrays = [ 53 | reduce(np.add, layer_updates) / exp_val_c_total 54 | for layer_updates in zip(*weighted_weights_c) 55 | ] 56 | weights_prime = weights_prime_a + weights_prime_c 57 | return weights_prime -------------------------------------------------------------------------------- /fed_sim/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, max_size=int(1e6),gpu_index = 0): 7 | self.max_size = max_size 8 | self.ptr = 0 9 | self.size = 0 10 | 11 | self.state = np.zeros((max_size, state_dim)) 12 | self.action = np.zeros((max_size, action_dim)) 13 | self.next_state = np.zeros((max_size, state_dim)) 14 | self.reward = np.zeros((max_size, 1)) 15 | self.not_done = np.zeros((max_size, 1)) 16 | self.device = torch.device('cuda', index=gpu_index) if torch.cuda.is_available() else 
torch.device('cpu') 17 | 18 | 19 | 20 | def add(self, state, action, next_state, reward, done): 21 | self.state[self.ptr] = state 22 | self.action[self.ptr] = action 23 | self.next_state[self.ptr] = next_state 24 | self.reward[self.ptr] = reward 25 | self.not_done[self.ptr] = 1. - done 26 | 27 | self.ptr = (self.ptr + 1) % self.max_size 28 | self.size = min(self.size + 1, self.max_size) 29 | 30 | 31 | def sample(self, batch_size): 32 | ind = np.random.randint(0, self.size, size=batch_size) 33 | 34 | return ( 35 | torch.FloatTensor(self.state[ind]).to(self.device), 36 | torch.FloatTensor(self.action[ind]).to(self.device), 37 | torch.FloatTensor(self.next_state[ind]).to(self.device), 38 | torch.FloatTensor(self.reward[ind]).to(self.device), 39 | torch.FloatTensor(self.not_done[ind]).to(self.device) 40 | ) 41 | 42 | 43 | def convert_D4RL(self, dataset,lower_lim=None,upper_lim=None): 44 | if ((lower_lim is None) and (upper_lim is None)): 45 | self.state = dataset['observations'] 46 | print(len(self.state)) 47 | self.action = dataset['actions'] 48 | self.next_state = dataset['next_observations'] 49 | self.reward = dataset['rewards'].reshape(-1,1) 50 | self.not_done = 1. - dataset['terminals'].reshape(-1,1) 51 | self.size = self.state.shape[0] 52 | 53 | else: 54 | self.state = dataset['observations'][lower_lim:upper_lim] 55 | self.action = dataset['actions'][lower_lim:upper_lim] 56 | self.next_state = dataset['next_observations'][lower_lim:upper_lim] 57 | self.reward = dataset['rewards'][lower_lim:upper_lim].reshape(-1,1) 58 | self.not_done = 1. - dataset['terminals'][lower_lim:upper_lim].reshape(-1,1) 59 | self.size = self.state.shape[0] 60 | print("Length of data set:",self.size) 61 | 62 | def normalize_states(self, eps = 1e-3): 63 | mean = self.state.mean(0,keepdims=True) 64 | std = self.state.std(0,keepdims=True) + eps 65 | self.state = (self.state - mean)/std 66 | self.next_state = (self.next_state - mean)/std 67 | return mean, std -------------------------------------------------------------------------------- /fed_flwr/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer(object): 6 | def __init__(self, state_dim, action_dim, max_size=int(1e6),gpu_index = 0): 7 | self.max_size = max_size 8 | self.ptr = 0 9 | self.size = 0 10 | 11 | self.state = np.zeros((max_size, state_dim)) 12 | self.action = np.zeros((max_size, action_dim)) 13 | self.next_state = np.zeros((max_size, state_dim)) 14 | self.reward = np.zeros((max_size, 1)) 15 | self.not_done = np.zeros((max_size, 1)) 16 | self.device = torch.device('cuda', index=gpu_index) \ 17 | if (torch.cuda.is_available() and gpu_index > -1) \ 18 | else torch.device('cpu') 19 | 20 | 21 | def add(self, state, action, next_state, reward, done): 22 | self.state[self.ptr] = state 23 | self.action[self.ptr] = action 24 | self.next_state[self.ptr] = next_state 25 | self.reward[self.ptr] = reward 26 | self.not_done[self.ptr] = 1. 
- done 27 | 28 | self.ptr = (self.ptr + 1) % self.max_size 29 | self.size = min(self.size + 1, self.max_size) 30 | 31 | 32 | def sample(self, batch_size): 33 | ind = np.random.randint(0, self.size, size=batch_size) 34 | 35 | return ( 36 | torch.FloatTensor(self.state[ind]).to(self.device), 37 | torch.FloatTensor(self.action[ind]).to(self.device), 38 | torch.FloatTensor(self.next_state[ind]).to(self.device), 39 | torch.FloatTensor(self.reward[ind]).to(self.device), 40 | torch.FloatTensor(self.not_done[ind]).to(self.device) 41 | ) 42 | 43 | 44 | def convert_D4RL(self, dataset,lower_lim=None,upper_lim=None): 45 | if ((lower_lim is None) and (upper_lim is None)): 46 | self.state = dataset['observations'] 47 | print(len(self.state)) 48 | self.action = dataset['actions'] 49 | self.next_state = dataset['next_observations'] 50 | self.reward = dataset['rewards'].reshape(-1,1) 51 | self.not_done = 1. - dataset['terminals'].reshape(-1,1) 52 | self.size = self.state.shape[0] 53 | 54 | else: 55 | self.state = dataset['observations'][lower_lim:upper_lim] 56 | self.action = dataset['actions'][lower_lim:upper_lim] 57 | self.next_state = dataset['next_observations'][lower_lim:upper_lim] 58 | self.reward = dataset['rewards'][lower_lim:upper_lim].reshape(-1,1) 59 | self.not_done = 1. - dataset['terminals'][lower_lim:upper_lim].reshape(-1,1) 60 | self.size = self.state.shape[0] 61 | print("Length of data set:",self.size) 62 | 63 | def normalize_states(self, eps = 1e-3): 64 | mean = self.state.mean(0,keepdims=True) 65 | std = self.state.std(0,keepdims=True) + eps 66 | self.state = (self.state - mean)/std 67 | self.next_state = (self.next_state - mean)/std 68 | return mean, std -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Federated Ensemble-Directed Offline Reinforcement Learning](https://arxiv.org/abs/2305.03097) 2 | 3 | Acceteped at the Thirty-Eighth Annual Conference on Neural Information Processing Systems (NeurIPS 2024). 4 | 5 | [Video of real world demonstration on TurtleBot](https://youtu.be/LplasPUm3jg) 6 | 7 | This codebase is based on the following publicly available git repositories: 8 | - TD3-BC: [sfujim/TD3_BC](https://github.com/sfujim/TD3_BC) 9 | - Flower: [adap/flower](https://github.com/adap/flower) 10 | - Structure of simulation: [CharlieDinh/pFedMe](https://github.com/CharlieDinh/pFedMe) 11 | 12 | The Python packages required to train FEDORA are listed in `requirements.txt` 13 | 14 | Our experiments are performed on Python 3.8 in a Ubuntu Linux environment. 15 | 16 | 17 | ## Federated learning with Flower 18 | 19 | Directory: `fed_flwr/` 20 | 21 | Specify federation parameters in `config/s_config.yml` amd `config/c_config.yml` 22 | Specify client learning parameters in `config/c_config.yml` 23 | 24 | a. Launch server and clients individually 25 | 26 | 1. Launch server 27 | python rl_server.py 28 | 29 | 2. Launch client (repeat for each client) 30 | python rl_client.py --gpu-index --eval-env --start-index --stop-index 31 | 32 | where 33 | gpu-index: index of CUDA device for PyTorch training 34 | eval-env: name of the D4RL data-set source for this client 35 | start-index: index of D4RL data-set to begin gathering data at 36 | stop-index: index of D4RL data-set to stop gathering data at 37 | 38 | b. To simplify this process, we share an example shell script with defalt parameters. 39 | 40 | 1. Specify client arguments in run_FEDORA.sh 41 | 2. 
38 | b. To simplify this process, we share an example shell script with default parameters. 39 | 40 | 1. Specify client arguments in run_FEDORA.sh 41 | 2. Launch the shell script 42 | bash run_FEDORA.sh 43 | 44 | The TensorBoard logs will be saved in a folder called 'Results' 45 | 46 | 47 | ## Federated learning simulation 48 | 49 | Run a single-threaded simulation of FEDORA. Helpful for training with limited computing resources. 50 | 51 | Directory: `fed_sim/` 52 | 53 | Launch main program (for default parameters, simply execute `python rl_main.py`) 54 | 55 | 56 | python rl_main.py --env-1 <env-1> --env-2 <env-2> --gpu-index-1 <gpu-index-1> --gpu-index-2 <gpu-index-2> --n-clients <n-clients> --ncpr <ncpr> \ 57 | --n-rounds <n-rounds> --dataset-size <dataset-size> --seed <seed> --batch-size <batch-size> --alpha-0 <alpha-0> --alpha-1 <alpha-1> --alpha-2 <alpha-2> \ 58 | --local-epochs <local-epochs> --temp-a <temp-a> --temp-c <temp-c> --decay-rate <decay-rate> 59 | 60 | where 61 | env-1: name of the D4RL data-set source for first half clients 62 | env-2: name of the D4RL data-set source for second half clients 63 | gpu-index-1: index of CUDA device for training first half clients 64 | gpu-index-2: index of CUDA device for training second half clients 65 | n-clients: total number of clients 66 | ncpr: number of clients participating in a round of federation 67 | n-rounds: total rounds of federation 68 | dataset-size: size of a client's data-set 69 | seed: random number seed 70 | and the others are FEDORA hyperparameters. 71 | 72 | The TensorBoard logs will be saved in a folder called 'Results' 73 | 74 | -------------------------------------------------------------------------------- /fed_sim/rl_main.py: -------------------------------------------------------------------------------- 1 | D4RL_SUPPRESS_IMPORT_ERROR=1 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import argparse 5 | import random 6 | from FLAlgorithms.servers.rlserver import RLFedServer 7 | import torch 8 | import d4rl 9 | import gym 10 | 11 | """ 12 | Server: 13 | Critic: Average Critics of all selected Users 14 | Actor: Average Actors of all selected Users 15 | User: 16 | Critic: Init policy to federated critic, keep a copy of federated critic, use this to take max amongst the current Q and the federated Q, use alpha_1 to prevent deviation 17 | Actor: Init policy to federated actor, keep a copy of federated actor, use it to prevent deviation with alpha_0, keep a copy of previous actor, 18 | use it to prevent deviation with alpha_2 19 | Federation Weights of Actor and Critic: 20 | According to the eval of the current policy on the local dataset 21 | Decay: 22 | Decay the weight of TD3-BC according to policy_val and server_val 23 | """ 24 | 25 | def main(datasets,num_users,num_users_per_round,batch_size,alpha_0,alpha_1,alpha_2,local_epochs,global_iters,dataset_size,seed,temp_a,temp_c,decay_rate,gpu_index_1,gpu_index_2): 26 | torch.manual_seed(seed) 27 | np.random.seed(seed) 28 | server = RLFedServer(datasets,num_users,num_users_per_round,batch_size,alpha_0,alpha_1,alpha_2,local_epochs,global_iters,dataset_size,seed,temp_a,temp_c,decay_rate,gpu_index_1,gpu_index_2) 29 | server.train() 30 | torch.cuda.empty_cache() 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--env-1", type=str, default="hopper-expert-v2") 36 | parser.add_argument("--env-2", type=str, default="hopper-medium-v2") 37 | parser.add_argument("--n-clients", type=int, default=10) 38 | parser.add_argument("--ncpr", type=int, default=10,help="Number of clients per round") 39 | parser.add_argument("--batch-size", type=int, default=256) 40 | parser.add_argument("--alpha-0", type=float, default=1.0) #Prox to fed policy 41 | parser.add_argument("--alpha-1", type=float, default=0.0) #Prox of critic 42 | 
parser.add_argument("--alpha-2", type=float, default=1.0) #Prox to prev policy 43 | parser.add_argument("--local-epochs", type=int, default=20) 44 | parser.add_argument("--n-rounds", type=int, default=1000) 45 | parser.add_argument("--dataset-size", type=int, default=5000) 46 | parser.add_argument("--seed", type=int, default=1) 47 | parser.add_argument("--temp-a", type=float, default=0.1) 48 | parser.add_argument("--temp-c", type=float, default=0.1) 49 | parser.add_argument("--decay-rate", type=float, default=0.995) 50 | parser.add_argument("--gpu-index-1", type=int, default=0) 51 | parser.add_argument("--gpu-index-2", type=int, default=1) 52 | args = parser.parse_args() 53 | 54 | datasets = [args.env_1,args.env_2] 55 | main( 56 | datasets=datasets, 57 | num_users = args.n_clients, 58 | num_users_per_round=args.ncpr, 59 | batch_size=args.batch_size, 60 | alpha_0=args.alpha_0, 61 | alpha_1=args.alpha_1, 62 | alpha_2=args.alpha_2, 63 | local_epochs=args.local_epochs, 64 | global_iters=args.n_rounds, 65 | dataset_size= args.dataset_size, 66 | seed = args.seed, 67 | temp_a = args.temp_a, 68 | temp_c = args.temp_c, 69 | decay_rate = args.decay_rate, 70 | gpu_index_1 = args.gpu_index_1, 71 | gpu_index_2 = args.gpu_index_2 72 | ) 73 | -------------------------------------------------------------------------------- /fed_flwr/rl_server.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import gym 6 | import d4rl 7 | import yaml 8 | import datetime 9 | from torch.utils.tensorboard import SummaryWriter 10 | from utils.nets import Actor, Critic 11 | from utils.flwr_utils import aggregate_rl 12 | import flwr as fl 13 | from typing import Dict, List, Optional, Tuple, Union, OrderedDict 14 | from flwr.server.client_manager import ClientManager 15 | from flwr.server.client_proxy import ClientProxy 16 | from flwr.server.strategy.aggregate import aggregate 17 | from flwr.common.logger import log 18 | from flwr.common import ( 19 | EvaluateRes, 20 | FitIns, 21 | FitRes, 22 | Parameters, 23 | Scalar, 24 | parameters_to_ndarrays, 25 | ndarrays_to_parameters, 26 | ) 27 | 28 | 29 | class ServerFedRL: 30 | def __init__(self, c_config, s_config) -> None: 31 | self.seed = c_config["seed"] 32 | torch.manual_seed(self.seed) 33 | np.random.seed(self.seed) 34 | self.env_name = s_config["env_1"] 35 | server_env = gym.make(self.env_name) 36 | self.server_device = torch.device('cpu') 37 | state_dim = server_env.observation_space.shape[0] 38 | action_dim = server_env.action_space.shape[0] 39 | max_action = float(server_env.action_space.high[0]) 40 | self.server_actor = Actor(state_dim, action_dim, max_action).to(self.server_device) 41 | self.server_critic = Critic(state_dim, action_dim).to(self.server_device) 42 | self.temp_a = s_config["temp_a"] 43 | self.temp_c = s_config["temp_c"] 44 | self.len_param_actor = len(self.server_actor.state_dict().keys()) 45 | self.len_param_critic = len(self.server_critic.state_dict().keys()) 46 | 47 | 48 | def set_parameters_actor(self, params): 49 | params_dict = zip(self.server_actor.state_dict().keys(), params) 50 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 51 | self.server_actor.load_state_dict(state_dict, strict=True) 52 | 53 | 54 | def set_parameters_critic(self, params): 55 | params_dict = zip(self.server_critic.state_dict().keys(), params) 56 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 
57 | self.server_critic.load_state_dict(state_dict, strict=True) 58 | 59 | 60 | def set_parameters(self, params): 61 | if len(params) != self.len_param_actor + self.len_param_critic: 62 | raise SystemExit("Error: Actor and Critic parameter length mismatch.") 63 | param_actor = params[:self.len_param_actor] 64 | param_critic = params[self.len_param_actor:] 65 | self.set_parameters_actor(param_actor) 66 | self.set_parameters_critic(param_critic) 67 | 68 | 69 | ## Policy Evaluation ## 70 | ## Given a policy, run its evaluation ## 71 | def select_action(self, state): 72 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.server_device) 73 | return self.server_actor(state).cpu().data.numpy().flatten() 74 | 75 | 76 | def eval_server_policy(self, mean=0, std=1, seed_offset=0, eval_episodes=2): 77 | eval_env = gym.make(self.env_name) 78 | avg_reward = 0. 79 | for _ in range(eval_episodes): 80 | state, done = eval_env.reset(), False 81 | while not done: 82 | state = (np.array(state).reshape(1,-1) - mean)/std 83 | action = self.select_action(state) 84 | state, reward, done, _ = eval_env.step(action) 85 | avg_reward += reward 86 | 87 | avg_reward /= eval_episodes 88 | d4rl_score = eval_env.get_normalized_score(avg_reward) * 100 89 | print(f"Evaluation of Server avg_reward: {avg_reward:.3f},\ 90 | D4RL score: {d4rl_score:.3f}") 91 | return avg_reward, d4rl_score 92 | 93 | 94 | 95 | class CustomMetricStrategy(fl.server.strategy.FedAvg): 96 | 97 | def set_server(self, server: ServerFedRL): 98 | self.server = server 99 | 100 | 101 | def configure_fit( 102 | self, server_round: int, parameters: Parameters, client_manager: ClientManager 103 | ) -> List[Tuple[ClientProxy, FitIns]]: 104 | ret_sup = super().configure_fit(server_round, parameters, client_manager) 105 | s_rwd, s_d4rl_score = self.server.eval_server_policy() 106 | writer.add_scalar("s_rwd", s_rwd, server_round) 107 | # print("round {}, \ts_rwd {:.3f}"\ 108 | # .format(server_round, s_rwd)) 109 | return ret_sup 110 | 111 | 112 | def aggregate_fit( 113 | self, 114 | server_round: int, 115 | results: List[Tuple[ClientProxy, FitRes]], 116 | failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], 117 | ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: 118 | 119 | pol_val = [] 120 | for client in range(len(results)): 121 | res = results[client] 122 | pol_val.append(res[1].metrics["c_pol_val"]) 123 | c_id = res[0].cid.split("ipv6:")[-1] 124 | 125 | if not results: 126 | return None, {} 127 | # Do not aggregate if there are failures and failures are not accepted 128 | if not self.accept_failures and failures: 129 | return None, {} 130 | 131 | # Convert results 132 | weights_results = [ 133 | (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples) 134 | for _, fit_res in results 135 | ] 136 | 137 | # Weigh and compute weights 138 | weights_updated = aggregate_rl(weights_results, pol_val, server.len_param_actor, \ 139 | server.temp_a, server.temp_c) 140 | parameters_aggregated = ndarrays_to_parameters(weights_updated) 141 | 142 | # Update server model weights 143 | self.server.set_parameters(weights_updated) 144 | 145 | metrics_aggregated = {} 146 | aggregated_weights = (parameters_aggregated, metrics_aggregated) 147 | 148 | return aggregated_weights 149 | 150 | 151 | def aggregate_evaluate( 152 | self, 153 | server_round: int, 154 | results: List[Tuple[ClientProxy, EvaluateRes]], 155 | failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], 156 | ) -> Tuple[Optional[float], Dict[str, Scalar]]: 157 | 158 | 
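        # Log each participating client's evaluation reward and decay factor to TensorBoard, then report the round's average client reward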
avg_c_rwd = 0.0 159 | for client in range(len(results)): 160 | res = results[client] 161 | c_id = res[0].cid.split("ipv6:")[-1] 162 | c_rwd = res[1].metrics["c_rwd"] 163 | c_d4rl_score = res[1].metrics["c_d4rl_score"] 164 | c_decay = res[1].metrics["c_decay"] 165 | avg_c_rwd += c_rwd 166 | 167 | writer.add_scalar("c_rwd/" + c_id, \ 168 | c_rwd, server_round) 169 | # writer.add_scalar("c_d4rl_score/" + c_id, \ 170 | # c_d4rl_score, server_round) 171 | writer.add_scalar("c_decay/" + c_id, \ 172 | c_decay, server_round) 173 | 174 | avg_c_rwd /= len(results) 175 | writer.add_scalar("avg_c_rwd", avg_c_rwd, server_round) 176 | 177 | # s_rwd, s_d4rl_score = self.server.eval_server_policy() 178 | # writer.add_scalar("s_rwd", s_rwd, server_round) 179 | 180 | # print("round {}, \ts_rwd {:.3f}, \ts_d4rl_score {:.3f} \tavg_c_rwd {:.3f}"\ 181 | # .format(server_round, s_rwd, s_d4rl_score, avg_c_rwd)) 182 | 183 | print("round {}, \tavg_c_rwd {:.3f}"\ 184 | .format(server_round, avg_c_rwd)) 185 | 186 | return super().aggregate_evaluate(server_round, results, failures) 187 | 188 | 189 | 190 | if __name__ == "__main__": 191 | with open("config/c_config.yml", "r") as config_file: 192 | c_config = yaml.safe_load(config_file) 193 | with open("config/s_config.yml", "r") as config_file: 194 | s_config = yaml.safe_load(config_file) 195 | 196 | total_env = s_config["env_1"] + "_" + s_config["env_2"] 197 | run_id = "FEDORA_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 198 | log_path = "Results/" + total_env + "/" + run_id + "/" 199 | writer = SummaryWriter(log_path) 200 | writer.add_text("Num Rounds", str(s_config["n_rounds"])) 201 | writer.add_text("Num Clients", str(s_config["n_clients"])) 202 | writer.add_text("Num Clients per Round", str(s_config["ncpr"])) 203 | writer.add_text("Local Epochs", str(c_config["local_epochs"])) 204 | writer.add_text("Alpha_0", str(c_config["alpha_0"])) 205 | writer.add_text("Alpha_1", str(c_config["alpha_1"])) 206 | writer.add_text("Alpha_2", str(c_config["alpha_2"])) 207 | writer.add_text("Temp A", str(s_config["temp_a"])) 208 | writer.add_text("Temp C", str(s_config["temp_c"])) 209 | writer.add_text("Seed", str(c_config["seed"])) 210 | writer.add_text("decay_rate", str(c_config["decay_rate"])) 211 | 212 | fraction_c = s_config["ncpr"] / s_config["n_clients"] 213 | min_c = s_config["ncpr"] 214 | num_c = s_config["n_clients"] 215 | 216 | server = ServerFedRL(c_config, s_config) 217 | 218 | strategy = CustomMetricStrategy( 219 | fraction_fit = fraction_c, 220 | fraction_evaluate = fraction_c, 221 | min_fit_clients = min_c, 222 | min_evaluate_clients = min_c, 223 | min_available_clients = num_c, 224 | ) 225 | 226 | strategy.set_server(server) 227 | 228 | hist = fl.server.start_server( 229 | config = fl.server.ServerConfig(num_rounds=s_config["n_rounds"]), 230 | strategy = strategy, 231 | server_address = c_config["server_ip"] 232 | ) -------------------------------------------------------------------------------- /fed_sim/FLAlgorithms/users/userrl.py: -------------------------------------------------------------------------------- 1 | D4RL_SUPPRESS_IMPORT_ERROR=1 2 | import torch 3 | import os 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import gym 8 | import copy 9 | from utils.rl_utils import * 10 | import d4rl 11 | 12 | ### Defining the actor and critic models ### 13 | class Actor(nn.Module): 14 | def __init__(self, state_dim, action_dim, max_action): 15 | super(Actor, self).__init__() 16 | 17 | self.l1 = 
nn.Linear(state_dim, 256) 18 | self.l2 = nn.Linear(256, 256) 19 | self.l3 = nn.Linear(256, action_dim) 20 | 21 | self.max_action = max_action 22 | 23 | 24 | def forward(self, state): 25 | a = F.relu(self.l1(state)) 26 | a = F.relu(self.l2(a)) 27 | return self.max_action * torch.tanh(self.l3(a)) 28 | 29 | 30 | class Critic(nn.Module): 31 | def __init__(self, state_dim, action_dim): 32 | super(Critic, self).__init__() 33 | 34 | # Q1 architecture 35 | self.l1 = nn.Linear(state_dim + action_dim, 256) 36 | self.l2 = nn.Linear(256, 256) 37 | self.l3 = nn.Linear(256, 1) 38 | 39 | # Q2 architecture 40 | self.l4 = nn.Linear(state_dim + action_dim, 256) 41 | self.l5 = nn.Linear(256, 256) 42 | self.l6 = nn.Linear(256, 1) 43 | 44 | 45 | def forward(self, state, action): 46 | sa = torch.cat([state, action], 1) 47 | 48 | q1 = F.relu(self.l1(sa)) 49 | q1 = F.relu(self.l2(q1)) 50 | q1 = self.l3(q1) 51 | 52 | q2 = F.relu(self.l4(sa)) 53 | q2 = F.relu(self.l5(q2)) 54 | q2 = self.l6(q2) 55 | return q1, q2 56 | 57 | 58 | def Q1(self, state, action): 59 | sa = torch.cat([state, action], 1) 60 | 61 | q1 = F.relu(self.l1(sa)) 62 | q1 = F.relu(self.l2(q1)) 63 | q1 = self.l3(q1) 64 | return q1 65 | 66 | 67 | class UserFedRL: 68 | def __init__( 69 | self, 70 | userid=1, 71 | gpu_index=0, 72 | eval_env="hopper-expert-v0", 73 | start_index=0, 74 | stop_index=2000, 75 | alpha_0=0, 76 | alpha_1=0, 77 | alpha_2=0, 78 | batch_size=256, 79 | seed=0, 80 | discount=0.99, 81 | tau=0.005, 82 | policy_noise=0.2, 83 | noise_clip=0.5, 84 | policy_freq=2, 85 | alpha=2.5, 86 | decay_rate=0.995 87 | ): 88 | self.userid = userid 89 | self.eval_env = eval_env 90 | self.alpha_0 = alpha_0 91 | self.alpha_1 = alpha_1 92 | self.alpha_2 = alpha_2 93 | self.batch_size = batch_size 94 | self.seed = seed 95 | env = gym.make(self.eval_env) 96 | state_dim = env.observation_space.shape[0] 97 | action_dim = env.action_space.shape[0] 98 | max_action = float(env.action_space.high[0]) 99 | self.device = torch.device('cuda', index=gpu_index) if torch.cuda.is_available() else torch.device('cpu') 100 | 101 | self.actor = Actor(state_dim, action_dim, max_action).to(self.device) 102 | self.actor_target = copy.deepcopy(self.actor) 103 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) 104 | self.prev_actor = copy.deepcopy(self.actor) ### 105 | 106 | self.critic = Critic(state_dim, action_dim).to(self.device) 107 | self.critic_target = copy.deepcopy(self.critic) 108 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) 109 | 110 | self.max_action = max_action 111 | self.discount = discount 112 | self.tau = tau 113 | self.policy_noise = policy_noise 114 | self.noise_clip = noise_clip 115 | self.policy_freq = policy_freq 116 | self.alpha = alpha 117 | self.total_it = 0 118 | self.pol_val = 0 119 | self.server_val = 0 120 | self.decay = 1 121 | self.decay_rate = decay_rate 122 | 123 | self.replay_buffer = ReplayBuffer(state_dim, action_dim,gpu_index=gpu_index) 124 | dataset = d4rl.qlearning_dataset(env) 125 | self.replay_buffer.convert_D4RL(dataset,lower_lim=start_index,upper_lim=stop_index) 126 | 127 | def get_parameters_actor(self): 128 | for param in self.actor.parameters(): 129 | param.detach() 130 | return self.actor.parameters() 131 | 132 | def get_parameters_critic(self): 133 | for param in self.critic.parameters(): 134 | param.detach() 135 | return self.critic.parameters() 136 | 137 | 138 | def set_parameters_actor(self,server_actor): 139 | for old_param, new_param in zip(self.actor.parameters(), 
server_actor.parameters()): 140 | old_param.data = new_param.data.clone().to(self.device) 141 | 142 | def set_parameters_critic(self,server_critic): 143 | for old_param, new_param in zip(self.critic.parameters(), server_critic.parameters()): 144 | old_param.data = new_param.data.clone().to(self.device) 145 | 146 | 147 | def train(self,local_epochs,server_actor,server_critic): 148 | server_actor = server_actor.to(self.device) 149 | server_critic = server_critic.to(self.device) 150 | total_epochs = (self.replay_buffer.size // self.batch_size) * local_epochs 151 | self.server_val = self.eval_pol(server_actor,server_critic) 152 | for eph in range(total_epochs): 153 | self.train_TD3(server_actor,server_critic) 154 | self.pol_val = self.eval_pol(self.actor,self.critic) #value of the current policy according to the current dataset 155 | self.prev_actor = copy.deepcopy(self.actor) ### 156 | if self.server_val > self.pol_val: 157 | self.decay = self.decay * self.decay_rate 158 | 159 | 160 | def eval_pol(self,actor,critic): 161 | state, action, next_state, reward, not_done = self.replay_buffer.sample(self.replay_buffer.size) 162 | with torch.no_grad(): 163 | pi = actor(state) 164 | pol_val = critic.Q1(state, pi).mean().cpu().numpy() 165 | return pol_val 166 | 167 | 168 | 169 | def select_action(self, state): 170 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 171 | return self.actor(state).cpu().data.numpy().flatten() 172 | 173 | 174 | def train_TD3(self,server_actor,server_critic): 175 | self.total_it += 1 176 | # Sample replay buffer 177 | state, action, next_state, reward, not_done = self.replay_buffer.sample(self.batch_size) 178 | 179 | with torch.no_grad(): 180 | # Select action according to policy and add clipped noise 181 | noise = ( 182 | torch.randn_like(action) * self.policy_noise 183 | ).clamp(-self.noise_clip, self.noise_clip) 184 | 185 | next_action = ( 186 | self.actor_target(next_state) + noise 187 | ).clamp(-self.max_action, self.max_action) 188 | #Computing the minimum of the server Qs, to take min for the target update 189 | fed_Q1, fed_Q2 = server_critic(next_state,next_action) 190 | # Compute the target Q value 191 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 192 | target_Q = torch.min(target_Q1, target_Q2) 193 | fed_min = torch.min(fed_Q1,fed_Q2) 194 | target_Q = torch.max(target_Q,fed_min) ## 195 | target_Q = reward + not_done * self.discount * target_Q 196 | 197 | # Get current Q estimates 198 | current_Q1, current_Q2 = self.critic(state, action) 199 | with torch.no_grad(): 200 | fed_Q1, fed_Q2 = server_critic(state,action) 201 | 202 | 203 | # Compute critic loss 204 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) + self.alpha_1 * F.mse_loss(current_Q1, fed_Q1) + self.alpha_1 * F.mse_loss(current_Q2, fed_Q2) 205 | # Optimize the critic 206 | self.critic_optimizer.zero_grad() 207 | critic_loss.backward() 208 | self.critic_optimizer.step() 209 | # Delayed policy updates 210 | if self.total_it % self.policy_freq == 0: 211 | 212 | # Compute actor loss 213 | pi = self.actor(state) 214 | server_pi = server_actor(state).detach() 215 | prev_pi = self.prev_actor(state).detach() ### 216 | Q = self.critic.Q1(state, pi) 217 | lmbda = self.alpha/Q.abs().mean().detach() 218 | actor_loss = -lmbda * self.decay * Q.mean() + self.decay * F.mse_loss(pi, action) + self.alpha_0 * F.mse_loss(pi,server_pi) + self.alpha_2 * F.mse_loss(pi,prev_pi) ### 219 | # Optimize the actor 220 | self.actor_optimizer.zero_grad() 221 | 
actor_loss.backward() 222 | self.actor_optimizer.step() 223 | 224 | # Update the frozen target models 225 | for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): 226 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 227 | 228 | for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): 229 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 230 | 231 | 232 | 233 | ## Policy Evaluation ## 234 | ## Given a policy, run its evaluation ## 235 | def eval_policy(self, mean=0, std=1, seed_offset=0, eval_episodes=2): 236 | eval_env = gym.make(self.eval_env) 237 | eval_env.seed(self.seed) 238 | avg_reward = 0. 239 | for _ in range(eval_episodes): 240 | state, done = eval_env.reset(), False 241 | while not done: 242 | state = (np.array(state).reshape(1,-1) - mean)/std 243 | action = self.select_action(state) 244 | state, reward, done, _ = eval_env.step(action) 245 | avg_reward += reward 246 | 247 | avg_reward /= eval_episodes 248 | d4rl_score = eval_env.get_normalized_score(avg_reward) * 100 249 | print(f"Evaluation of Client {self.userid} avg_reward: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f} dataset: {self.eval_env} decay: {self.decay:.3f}") 250 | return avg_reward 251 | 252 | 253 | 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /fed_sim/FLAlgorithms/servers/rlserver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import gym 7 | import copy 8 | from FLAlgorithms.users.userrl import UserFedRL 9 | from torch.utils.tensorboard import SummaryWriter 10 | import datetime 11 | 12 | ### Defining the actor and critic models ### 13 | class Actor(nn.Module): 14 | def __init__(self, state_dim, action_dim, max_action): 15 | super(Actor, self).__init__() 16 | 17 | self.l1 = nn.Linear(state_dim, 256) 18 | self.l2 = nn.Linear(256, 256) 19 | self.l3 = nn.Linear(256, action_dim) 20 | 21 | self.max_action = max_action 22 | 23 | 24 | def forward(self, state): 25 | a = F.relu(self.l1(state)) 26 | a = F.relu(self.l2(a)) 27 | return self.max_action * torch.tanh(self.l3(a)) 28 | 29 | 30 | class Critic(nn.Module): 31 | def __init__(self, state_dim, action_dim): 32 | super(Critic, self).__init__() 33 | 34 | # Q1 architecture 35 | self.l1 = nn.Linear(state_dim + action_dim, 256) 36 | self.l2 = nn.Linear(256, 256) 37 | self.l3 = nn.Linear(256, 1) 38 | 39 | # Q2 architecture 40 | self.l4 = nn.Linear(state_dim + action_dim, 256) 41 | self.l5 = nn.Linear(256, 256) 42 | self.l6 = nn.Linear(256, 1) 43 | 44 | 45 | def forward(self, state, action): 46 | sa = torch.cat([state, action], 1) 47 | 48 | q1 = F.relu(self.l1(sa)) 49 | q1 = F.relu(self.l2(q1)) 50 | q1 = self.l3(q1) 51 | 52 | q2 = F.relu(self.l4(sa)) 53 | q2 = F.relu(self.l5(q2)) 54 | q2 = self.l6(q2) 55 | return q1, q2 56 | 57 | 58 | def Q1(self, state, action): 59 | sa = torch.cat([state, action], 1) 60 | 61 | q1 = F.relu(self.l1(sa)) 62 | q1 = F.relu(self.l2(q1)) 63 | q1 = self.l3(q1) 64 | return q1 65 | 66 | 67 | ## Defining the RL federation server ### 68 | 69 | class RLFedServer: 70 | def __init__(self,datasets,num_users,num_users_per_round,batch_size,alpha_0,alpha_1,alpha_2,local_epochs,global_iters,dataset_size,seed,temp_a,temp_c,decay_rate,gpu_index_1,gpu_index_2): 71 | self.env = datasets[0] 72 | self.num_users = 
num_users 73 | self.num_users_per_round = num_users_per_round 74 | server_env = gym.make(datasets[0]) 75 | self.server_device = torch.device('cpu') 76 | state_dim = server_env.observation_space.shape[0] 77 | action_dim = server_env.action_space.shape[0] 78 | max_action = float(server_env.action_space.high[0]) 79 | ## Creating server versions of Actor and Critic 80 | self.server_actor = Actor(state_dim, action_dim, max_action).to(self.server_device) 81 | self.server_critic = Critic(state_dim, action_dim).to(self.server_device) 82 | self.users = [] 83 | self.total_train_samples = 0 84 | self.dataset_size = dataset_size 85 | self.local_epochs = local_epochs 86 | self.global_iter = global_iters 87 | total_env = datasets[0] + "_" + datasets[1] 88 | run_id = "FEDORA_{}".format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 89 | log_path = "Results/" + total_env + "/" + run_id + "/" 90 | self.writer = SummaryWriter(log_path) 91 | self.temp_a = temp_a 92 | self.temp_c = temp_c 93 | self.writer.add_text("Num Clients",str(self.num_users)) 94 | self.writer.add_text("Num Clients per Round",str(self.num_users_per_round)) 95 | self.writer.add_text("Each User Samples",str(self.dataset_size)) 96 | self.writer.add_text("Local Epochs",str(self.local_epochs)) 97 | self.writer.add_text("Alpha_0",str(alpha_0)) 98 | self.writer.add_text("Alpha_1",str(alpha_1)) 99 | self.writer.add_text("Alpha_2",str(alpha_2)) 100 | self.writer.add_text("Temp A",str(temp_a)) 101 | self.writer.add_text("Temp C",str(temp_c)) 102 | self.writer.add_text("Seed",str(seed)) 103 | self.writer.add_text("decay_rate",str(decay_rate)) 104 | 105 | kwargs = { 106 | "discount": 0.99, 107 | "tau": 0.005, 108 | "policy_noise": 0.2 * max_action, 109 | "noise_clip": 0.5 * max_action, 110 | "policy_freq": 2, 111 | "alpha": 2.5, 112 | "seed":seed, 113 | "alpha_0":alpha_0, 114 | "alpha_1":alpha_1, 115 | "alpha_2":alpha_2, 116 | "batch_size":batch_size, 117 | "decay_rate":decay_rate 118 | } 119 | 120 | 121 | 122 | ### Creating clients with different datasets ### 123 | for i in range(self.num_users // 2): 124 | print("Creating client ",i) 125 | print("Env: " + datasets[0]) 126 | self.writer.add_text("Env/"+str(i),str(datasets[0])) 127 | print("Start Index", i*dataset_size) 128 | print("Stop Index", (i+1)*dataset_size) 129 | user = UserFedRL(userid=i,gpu_index=gpu_index_1,eval_env=datasets[0], start_index=i*dataset_size, stop_index=(i+1)*dataset_size,**kwargs) 130 | self.users.append(user) 131 | print("Replay Buffer Size",user.replay_buffer.size) 132 | print("="* 20) 133 | self.total_train_samples += user.replay_buffer.size 134 | 135 | for i in range(self.num_users // 2,self.num_users): 136 | print("Creating client ",i) 137 | print("Env " + datasets[1]) 138 | self.writer.add_text("Env/"+str(i),str(datasets[1])) 139 | print("Start Index", i*dataset_size) 140 | print("Stop Index", (i+1)*dataset_size) 141 | user = UserFedRL(userid=i,gpu_index=gpu_index_2,eval_env=datasets[1], start_index=i*dataset_size, stop_index=(i+1)*dataset_size,**kwargs) 142 | self.users.append(user) 143 | print("Replay Buffer Size",user.replay_buffer.size) 144 | print("="* 20) 145 | self.total_train_samples += user.replay_buffer.size 146 | 147 | def send_parameters_actor(self): 148 | assert (self.users is not None and len(self.users) > 0) 149 | for user in self.users: 150 | user.set_parameters_actor(self.server_actor) 151 | 152 | def send_parameters_critic(self): 153 | assert (self.users is not None and len(self.users) > 0) 154 | for user in self.users: 155 | 
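            # Push the current federated critic to every client before the next round of local training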
user.set_parameters_critic(self.server_critic) 156 | 157 | 158 | def add_parameters_actor(self, user, ratio): 159 | for server_actor_param, user_actor_param in zip(self.server_actor.parameters(), user.get_parameters_actor()): 160 | server_actor_param.data = server_actor_param.data + user_actor_param.data.cpu().clone() * ratio 161 | 162 | 163 | def add_parameters_critic(self, user, ratio): 164 | for server_critic_param, user_critic_param in zip(self.server_critic.parameters(), user.get_parameters_critic()): 165 | server_critic_param.data = server_critic_param.data + user_critic_param.data.cpu().clone() * ratio 166 | 167 | 168 | def select_users(self, round, num_users): 169 | if(num_users == len(self.users)): 170 | print("All users are selected") 171 | return self.users 172 | 173 | num_users = min(num_users, len(self.users)) 174 | return np.random.choice(self.users, num_users, replace=False) 175 | 176 | 177 | def aggregate_actor_parameters(self,glob_iter): 178 | assert (self.users is not None and len(self.users) > 0) 179 | for param in self.server_actor.parameters(): 180 | param.data = torch.zeros_like(param.data) 181 | total_train = 0 182 | 183 | for user in self.selected_users: 184 | total_train += np.exp(self.temp_a * user.pol_val) 185 | for user in self.selected_users: 186 | ratio = np.exp(self.temp_a * user.pol_val) / total_train 187 | print(ratio) 188 | self.writer.add_scalar("c_ratio/"+str(user.userid),ratio,glob_iter+1) 189 | self.add_parameters_actor(user,ratio) 190 | 191 | def aggregate_critic_parameters(self): #### 192 | assert (self.users is not None and len(self.users) > 0) ##### 193 | for param in self.server_critic.parameters(): #### 194 | param.data = torch.zeros_like(param.data) ##### 195 | total_train = 0 196 | for user in self.selected_users: 197 | total_train += np.exp(self.temp_c * user.pol_val) 198 | for user in self.selected_users: 199 | ratio = np.exp(self.temp_c * user.pol_val) / total_train 200 | self.add_parameters_critic(user, ratio) 201 | 202 | def train(self): 203 | 204 | for glob_iter in range(self.global_iter): 205 | avg_rwd = [] 206 | print("-------------Round number: ",glob_iter, " -------------") 207 | self.send_parameters_actor() 208 | self.send_parameters_critic() 209 | global_reward = self.eval_server_policy() 210 | self.writer.add_scalar("s_rwd",global_reward,glob_iter+1) 211 | self.selected_users = self.select_users(glob_iter,self.num_users_per_round) 212 | for user in self.selected_users: 213 | server_actor_clone = copy.deepcopy(self.server_actor) 214 | server_critic_clone = copy.deepcopy(self.server_critic) 215 | user.train(self.local_epochs,server_actor_clone,server_critic_clone) 216 | user_reward = user.eval_policy() 217 | avg_rwd.append(user_reward) 218 | self.writer.add_scalar("c_rwd/"+str(user.userid),user_reward,glob_iter+1) 219 | self.writer.add_scalar("c_decay/"+str(user.userid),user.decay,glob_iter+1) 220 | print('Average reward over client:',np.mean(avg_rwd)) 221 | self.writer.add_scalar("avg_c_rwd",np.mean(avg_rwd),glob_iter+1) 222 | self.aggregate_actor_parameters(glob_iter=glob_iter) 223 | self.aggregate_critic_parameters() 224 | 225 | ## Policy Evaluation ## 226 | ## Given a policy, run its evaluation ## 227 | def select_action(self, state): 228 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.server_device) 229 | return self.server_actor(state).cpu().data.numpy().flatten() 230 | 231 | def eval_server_policy(self, mean=0, std=1, seed_offset=0, eval_episodes=2): 232 | eval_env = gym.make(self.env) 233 | avg_reward = 0. 
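        # Roll out the server actor for eval_episodes episodes and average the undiscounted return; a D4RL-normalized score is also computed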
234 | for _ in range(eval_episodes): 235 | state, done = eval_env.reset(), False 236 | while not done: 237 | state = (np.array(state).reshape(1,-1) - mean)/std 238 | action = self.select_action(state) 239 | state, reward, done, _ = eval_env.step(action) 240 | avg_reward += reward 241 | 242 | avg_reward /= eval_episodes 243 | d4rl_score = eval_env.get_normalized_score(avg_reward) * 100 244 | print(f"Evaluation of Server avg_reward: {avg_reward:.3f}, D4RL score: {d4rl_score:.3f}") 245 | return avg_reward -------------------------------------------------------------------------------- /fed_flwr/rl_client.py: -------------------------------------------------------------------------------- 1 | import flwr as fl 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | import gym 7 | import d4rl 8 | import argparse 9 | import yaml 10 | import copy 11 | from utils.rl_utils import ReplayBuffer 12 | from utils.nets import Actor, Critic 13 | 14 | 15 | 16 | class ClientFedRL: 17 | 18 | def __init__(self, gpu_index, eval_env, start_index, stop_index, \ 19 | c_config): 20 | self.eval_env = eval_env 21 | self.seed = c_config["seed"] 22 | torch.manual_seed(self.seed) 23 | np.random.seed(self.seed) 24 | self.device = torch.device('cuda', index=gpu_index) \ 25 | if (torch.cuda.is_available() and gpu_index > -1) \ 26 | else torch.device('cpu') 27 | env = gym.make(self.eval_env) 28 | state_dim = env.observation_space.shape[0] 29 | action_dim = env.action_space.shape[0] 30 | max_action = float(env.action_space.high[0]) 31 | self.max_action = max_action 32 | self.alpha_0 = c_config["alpha_0"] 33 | self.alpha_1 = c_config["alpha_1"] 34 | self.alpha_2 = c_config["alpha_2"] 35 | self.batch_size = c_config["batch_size"] 36 | self.discount = c_config["discount"] 37 | self.tau = c_config["tau"] 38 | self.policy_noise = c_config["policy_noise_f"] * max_action 39 | self.noise_clip = c_config["noise_clip_f"] * max_action 40 | self.policy_freq = c_config["policy_freq"] 41 | self.alpha = c_config["alpha"] 42 | self.decay_rate =c_config["decay_rate"] 43 | self.l_r = c_config["l_r"] 44 | self.local_epochs = c_config["local_epochs"] 45 | 46 | self.total_it = 0 47 | self.pol_val = 0 48 | self.server_val = 0 49 | self.decay = 1 50 | 51 | self.actor = Actor(state_dim, action_dim, max_action).to(self.device) 52 | self.actor_target = copy.deepcopy(self.actor) 53 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.l_r) 54 | self.prev_actor = copy.deepcopy(self.actor) 55 | self.critic = Critic(state_dim, action_dim).to(self.device) 56 | self.critic_target = copy.deepcopy(self.critic) 57 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.l_r) 58 | self.len_param_actor = len(self.actor.state_dict().keys()) 59 | self.len_param_critic = len(self.critic.state_dict().keys()) 60 | 61 | self.replay_buffer = ReplayBuffer(state_dim, action_dim, gpu_index=gpu_index) 62 | dataset = d4rl.qlearning_dataset(env) 63 | self.replay_buffer.convert_D4RL(dataset, lower_lim=start_index, upper_lim=stop_index) 64 | 65 | 66 | def get_parameters_actor(self): 67 | return [val.cpu().numpy() for _, val in self.actor.state_dict().items()] 68 | 69 | 70 | def get_parameters_critic(self): 71 | return [val.cpu().numpy() for _, val in self.critic.state_dict().items()] 72 | 73 | 74 | def get_parameters_combined(self): 75 | param_actor = self.get_parameters_actor() 76 | param_critic = self.get_parameters_critic() 77 | param_combined = param_actor + 
param_critic 78 | return param_combined 79 | 80 | 81 | def set_parameters_actor(self, params): 82 | params_dict = zip(self.actor.state_dict().keys(), params) 83 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 84 | self.actor.load_state_dict(state_dict, strict=True) 85 | 86 | 87 | def set_parameters_critic(self, params): 88 | params_dict = zip(self.critic.state_dict().keys(), params) 89 | state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) 90 | self.critic.load_state_dict(state_dict, strict=True) 91 | 92 | 93 | def set_parameters_combined(self, params): 94 | if len(params) != self.len_param_actor + self.len_param_critic: 95 | raise SystemExit("Error: Actor and Critic parameter length mismatch.") 96 | param_actor = params[:self.len_param_actor] 97 | param_critic = params[self.len_param_actor:] 98 | self.set_parameters_actor(param_actor) 99 | self.set_parameters_critic(param_critic) 100 | 101 | 102 | def eval_pol(self, actor, critic): 103 | state, action, next_state, reward, not_done = \ 104 | self.replay_buffer.sample(self.replay_buffer.size) 105 | with torch.no_grad(): 106 | pi = actor(state) 107 | pol_val = critic.Q1(state, pi).mean().cpu().numpy() 108 | return pol_val 109 | 110 | 111 | def select_action(self, state): 112 | state = torch.FloatTensor(state.reshape(1, -1)).to(self.device) 113 | return self.actor(state).cpu().data.numpy().flatten() 114 | 115 | 116 | def train(self): 117 | server_actor = copy.deepcopy(self.actor) 118 | server_critic = copy.deepcopy(self.critic) 119 | total_epochs = (self.replay_buffer.size // self.batch_size) * self.local_epochs 120 | self.server_val = self.eval_pol(server_actor, server_critic) 121 | for epoch in range(total_epochs): 122 | self.train_TD3(server_actor, server_critic) 123 | # Value of the current policy according to the current dataset 124 | self.pol_val = self.eval_pol(self.actor, self.critic).item() 125 | self.prev_actor = copy.deepcopy(self.actor) 126 | if self.server_val > self.pol_val: 127 | self.decay = self.decay * self.decay_rate 128 | 129 | 130 | def train_TD3(self, server_actor, server_critic): 131 | self.total_it += 1 132 | # Sample replay buffer 133 | state, action, next_state, reward, not_done = \ 134 | self.replay_buffer.sample(self.batch_size) 135 | 136 | with torch.no_grad(): 137 | # Select action according to policy and add clipped noise 138 | noise = ( 139 | torch.randn_like(action) * self.policy_noise 140 | ).clamp(-self.noise_clip, self.noise_clip) 141 | next_action = ( 142 | self.actor_target(next_state) + noise 143 | ).clamp(-self.max_action, self.max_action) 144 | # Computing the minimum of the server Qs, to take min for the target update 145 | fed_Q1, fed_Q2 = server_critic(next_state,next_action) 146 | # Compute the target Q value 147 | target_Q1, target_Q2 = self.critic_target(next_state, next_action) 148 | target_Q = torch.min(target_Q1, target_Q2) 149 | fed_min = torch.min(fed_Q1,fed_Q2) 150 | target_Q = torch.max(target_Q,fed_min) 151 | target_Q = reward + not_done * self.discount * target_Q 152 | 153 | # Get current Q estimates 154 | current_Q1, current_Q2 = self.critic(state, action) 155 | with torch.no_grad(): 156 | fed_Q1, fed_Q2 = server_critic(state,action) 157 | 158 | # Compute critic loss 159 | critic_loss = F.mse_loss(current_Q1, target_Q) \ 160 | + F.mse_loss(current_Q2, target_Q) + self.alpha_1 \ 161 | * F.mse_loss(current_Q1, fed_Q1) + self.alpha_1 \ 162 | * F.mse_loss(current_Q2, fed_Q2) 163 | # Optimize the critic 164 | self.critic_optimizer.zero_grad() 165 | 
critic_loss.backward() 166 | self.critic_optimizer.step() 167 | # Delayed policy updates 168 | if self.total_it % self.policy_freq == 0: 169 | 170 | # Compute actor loss 171 | pi = self.actor(state) 172 | server_pi = server_actor(state).detach() 173 | prev_pi = self.prev_actor(state).detach() 174 | Q = self.critic.Q1(state, pi) 175 | lmbda = self.alpha/Q.abs().mean().detach() 176 | actor_loss = -lmbda * self.decay * Q.mean() \ 177 | + self.decay * F.mse_loss(pi, action) \ 178 | + self.alpha_0 * F.mse_loss(pi,server_pi) \ 179 | + self.alpha_2 * F.mse_loss(pi,prev_pi) 180 | # Optimize the actor 181 | self.actor_optimizer.zero_grad() 182 | actor_loss.backward() 183 | self.actor_optimizer.step() 184 | 185 | # Update the frozen target models 186 | for param, target_param in \ 187 | zip(self.critic.parameters(), self.critic_target.parameters()): 188 | target_param.data.copy_(self.tau * param.data \ 189 | + (1 - self.tau) * target_param.data) 190 | 191 | for param, target_param in \ 192 | zip(self.actor.parameters(), self.actor_target.parameters()): 193 | target_param.data.copy_(self.tau * param.data \ 194 | + (1 - self.tau) * target_param.data) 195 | 196 | 197 | ## Policy Evaluation ## 198 | ## Given a policy, run its evaluation ## 199 | def eval_policy(self, mean=0, std=1, seed_offset=0, eval_episodes=2): 200 | eval_env = gym.make(self.eval_env) 201 | eval_env.seed(self.seed) 202 | avg_reward = 0. 203 | for _ in range(eval_episodes): 204 | state, done = eval_env.reset(), False 205 | while not done: 206 | state = (np.array(state).reshape(1,-1) - mean)/std 207 | action = self.select_action(state) 208 | state, reward, done, _ = eval_env.step(action) 209 | avg_reward += reward 210 | 211 | avg_reward /= eval_episodes 212 | d4rl_score = eval_env.get_normalized_score(avg_reward) * 100 213 | # print(f"Evaluation of client ; c_rwd: {avg_reward:.3f}, \ 214 | # D4RL score: {d4rl_score:.3f} dataset: {self.eval_env} c_decay: {self.decay:.3f}") 215 | return avg_reward, d4rl_score, self.decay 216 | 217 | 218 | 219 | class NumPyClientRL(fl.client.NumPyClient): 220 | def get_parameters(self, config=None): 221 | return client.get_parameters_combined() 222 | 223 | 224 | def set_parameters(self, params): 225 | client.set_parameters_combined(params) 226 | 227 | 228 | def fit(self, params, config=None): 229 | self.set_parameters(params) 230 | client.train() 231 | c_pol_val = client.pol_val 232 | return self.get_parameters(), 1, {"c_pol_val":c_pol_val} 233 | 234 | 235 | def evaluate(self, params, config=None): 236 | #self.set_parameters(params) 237 | c_rwd, c_d4rl_score, c_decay = client.eval_policy() 238 | return 0.0, 1, {"c_rwd":c_rwd, "c_d4rl_score":c_d4rl_score, \ 239 | "c_decay":c_decay} 240 | 241 | 242 | 243 | if __name__ == "__main__": 244 | 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument("--gpu-index", type=int, default=-1, \ 247 | help="GPU index for training, default:CPU") 248 | parser.add_argument("--eval-env", type=str, default="hopper-expert-v0", \ 249 | help="client gym environment") 250 | parser.add_argument("--start-index", type=int, default=0, \ 251 | help="start index of d4rl sample") 252 | parser.add_argument("--stop-index", type=int, default=2000, \ 253 | help="stop index of d4rl sample") 254 | args = parser.parse_args() 255 | 256 | with open("config/c_config.yml", "r") as config_file: 257 | c_config = yaml.safe_load(config_file) 258 | 259 | client = ClientFedRL(args.gpu_index, args.eval_env, \ 260 | args.start_index, args.stop_index, c_config) 261 | 262 | 
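    # Connect to the Flower server at server_ip and participate in federated fit/evaluate rounds until the server finishes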
fl.client.start_numpy_client(server_address=c_config["server_ip"], \ 263 | client=NumPyClientRL()) --------------------------------------------------------------------------------
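The exponentiated aggregation used in `fed_flwr/utils/flwr_utils.py` (`aggregate_rl`) and in `fed_sim/FLAlgorithms/servers/rlserver.py` weights each selected client's actor and critic by `exp(temp * pol_val)`, normalized over the selected clients, where `pol_val` is the average Q-value of the client's policy on its own dataset. A minimal standalone sketch of that weighting (the helper `fedora_weights` below is illustrative and not part of the repository):

    import numpy as np

    def fedora_weights(pol_vals, temp=0.1):
        # exp(temp * v_i) / sum_j exp(temp * v_j): clients whose policies are
        # valued higher on their own data receive a larger share of the
        # federated actor/critic
        w = np.exp(temp * np.asarray(pol_vals, dtype=np.float64))
        return w / w.sum()

    # With the default temp_a = 0.1 from config/s_config.yml:
    print(fedora_weights([300.0, 250.0, 100.0]))  # -> roughly [0.993, 0.0067, 2e-09]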