├── torchrl.zip
├── constopt-pytorch.zip
├── supplementary material of pami.pdf
├── readme.txt
├── single AC.py
├── distral.py
├── gradsur.py
├── SAC.py
├── soft-module.py
└── CAMRL.py

/torchrl.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huanghanchi/CAMRL/HEAD/torchrl.zip
--------------------------------------------------------------------------------
/constopt-pytorch.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huanghanchi/CAMRL/HEAD/constopt-pytorch.zip
--------------------------------------------------------------------------------
/supplementary material of pami.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huanghanchi/CAMRL/HEAD/supplementary material of pami.pdf
--------------------------------------------------------------------------------
/readme.txt:
--------------------------------------------------------------------------------
1 | 1. Gym-minigrid Env: https://github.com/maximecb/gym-minigrid
2 | Implementation of Mastering Rate based Curriculum Learning: https://github.com/lcswillems/automatic-curriculum
3 |
4 | 2. Meta-world Env: https://github.com/RchalYang/metaworld
5 |
6 | 3. Atari Env: https://github.com/mgbellemare/Arcade-Learning-Environment
7 |
8 | 4. Ravens Env: https://github.com/google-research/ravens
9 |
10 | 5. RLBench Env: https://github.com/stepjam/RLBench
11 |
12 |
13 | Code of YOLOR: https://github.com/WongKinYiu/yolor
14 | Code of CAMRL: ./CAMRL.py
15 | Code of SAC: ./SAC.py
16 | Code of Distral: ./distral.py
17 | Code of Gradient Surgery: ./gradsur.py
18 | Code of Single Actor Critic: ./single AC.py
19 | Code of Soft Module: ./soft-module.py
20 |
--------------------------------------------------------------------------------
/single AC.py:
--------------------------------------------------------------------------------
1 | import metaworld
2 | import random
3 | import argparse
4 | import gym
5 | import numpy as np
6 | from itertools import count
7 | from collections import namedtuple
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from torch.autograd import Variable
13 | from torch.distributions import Categorical
14 | import torch_ac
15 |
16 | class Single(nn.Module):
17 |
18 |     def __init__(self, tasks=10):  # number of MT10 tasks; envs is created further down, after this class is defined
19 |         super(Single, self).__init__()
20 |         self.actor = torch.nn.ModuleList([nn.Sequential(
21 |             nn.Linear(12, 128),
22 |             nn.ReLU(),
23 |             nn.Linear(128, 64),
24 |             nn.ReLU(),
25 |             nn.Linear(64, 32),
26 |             nn.ReLU(),
27 |             nn.Linear(32, 16),
28 |             nn.ReLU(),
29 |             nn.Linear(16, 4)
30 |         ) for i in range(tasks)])
31 |         self.value_head = torch.nn.ModuleList([nn.Sequential(
32 |             nn.Linear(12, 128),
33 |             nn.ReLU(),
34 |             nn.Linear(128, 64),
35 |             nn.ReLU(),
36 |             nn.Linear(64, 32),
37 |             nn.ReLU(),
38 |             nn.Linear(32, 16),
39 |             nn.ReLU(),
40 |             nn.Linear(16, 1)
41 |         ) for i in range(tasks)])
42 |
43 |         self.saved_actions = [[] for i in range(tasks)]
44 |         self.rewards = [[] for i in range(tasks)]
45 |         self.tasks = tasks
46 |
47 |     def forward(self, x):
48 |         tmp = []
49 |         for i in range(self.tasks):
50 |             tmp.append(F.softmax(self.actor[i](x) - self.actor[i](x).max(), dim=-1))
51 |         state_values = self.value_head[index](x)  # index is the global task id set in the training loop
52 |         return tmp, state_values
53 |
54 | def select_action(state, tasks, index):
55 |     state = torch.tensor(list(state)).float()
56 |     probs, state_value = 
model(Variable(state)) 57 | 58 | # Obtain the most probable action for each one of the policies 59 | actions = [] 60 | for i in range(tasks): 61 | model.saved_actions[i].append(SavedAction(probs[i].log().dot(probs[i]), state_value)) 62 | return probs, state_value 63 | 64 | def finish_episode(tasks, alpha, beta, gamma): 65 | 66 | ### Calculate loss function according to Equation 1 67 | R = 0 68 | saved_actions = model.saved_actions[index] 69 | policy_losses = [] 70 | value_losses = [] 71 | 72 | ## Obtain the discounted rewards backwards 73 | rewards = [] 74 | for r in model.rewards[index][::-1]: 75 | R = r + gamma * R 76 | rewards.insert(0, R) 77 | 78 | ## Standardize the rewards to be unit normal (to control the gradient estimator variance) 79 | rewards = torch.Tensor(rewards) 80 | rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps) 81 | 82 | for (log_prob, value), r in zip(saved_actions, rewards): 83 | reward = r - value.data[0] 84 | policy_losses.append(-log_prob * reward) 85 | value_losses.append(F.smooth_l1_loss(value, Variable(torch.Tensor([r])))) 86 | 87 | optimizer.zero_grad() 88 | loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum() 89 | loss.backward() 90 | optimizer.step() 91 | 92 | # Clean memory 93 | for i in range(tasks): 94 | del model.rewards[i][:] 95 | del model.saved_actions[i][:] 96 | 97 | torch.manual_seed(1337) 98 | SavedAction = namedtuple('SavedAction', ['log_prob', 'value']) 99 | ml10 = metaworld.MT10() # Construct the benchmark, sampling tasks 100 | envs = [] 101 | for name, env_cls in ml10.train_classes.items(): 102 | env = env_cls() 103 | task = random.choice([task for task in ml10.train_tasks 104 | if task.env_name == name]) 105 | env.set_task(task) 106 | envs.append(env) 107 | 108 | file_name = "Single" 109 | batch_size = 128 110 | alpha = 0.5 111 | beta = 0.5 112 | gamma = 0.999 113 | is_plot = False 114 | num_episodes = 500 115 | max_num_steps_per_episode = 10000 116 | learning_rate = 0.001 117 | tasks = len(envs) 118 | rewardsRec = [[] for _ in range(len(envs))] 119 | model = Single( ) 120 | optimizer = optim.Adam(model.parameters(), lr=3e-2) 121 | 122 | for rnd in range(10000): 123 | for index, env in enumerate(envs): 124 | total_reward = 0 125 | state = env.reset() 126 | for t in range(200): # Don't infinite loop while learning 127 | probs, state_value = select_action(state, tasks, index ) 128 | state, reward, done, _ = env.step(probs[index].detach().numpy()) 129 | model.rewards[index].append(reward) 130 | total_reward += reward 131 | if done: 132 | break 133 | print(rnd, index, total_reward) 134 | rewardsRec[index].append(total_reward) 135 | finish_episode(tasks, alpha, beta, gamma ) 136 | np.save('mt10_single_rewardsRec.npy', rewardsRec) 137 | torch.save(model.state_dict(), 'mt10_single_params.pkl') 138 | -------------------------------------------------------------------------------- /distral.py: -------------------------------------------------------------------------------- 1 | import metaworld 2 | import random 3 | from torch.distributions.normal import Normal 4 | import math 5 | 6 | class parser: 7 | def __init__(self): 8 | self.gamma = 0.99 9 | self.alpha = 0.5 10 | self.beta = .5 11 | self.seed = 543 12 | self.render = False 13 | self.log_interval = 10 14 | self.envs = envs 15 | 16 | def normalized_columns_initializer(weights, std=1.0): 17 | out = torch.randn(weights.size()) 18 | out *= std / torch.sqrt(out.pow(2).sum(1).expand_as(out)) 19 | return out 20 | 21 | def weights_init(m): 22 | classname 
= m.__class__.__name__ 23 | if classname.find('Conv') != -1: 24 | weight_shape = list(m.weight.data.size()) 25 | fan_in = np.prod(weight_shape[1:4]) 26 | fan_out = np.prod(weight_shape[2:4]) * weight_shape[0] 27 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 28 | m.weight.data.uniform_(-w_bound, w_bound) 29 | m.bias.data.fill_(0) 30 | elif classname.find('Linear') != -1: 31 | weight_shape = list(m.weight.data.size()) 32 | fan_in = weight_shape[1] 33 | fan_out = weight_shape[0] 34 | w_bound = np.sqrt(6. / (fan_in + fan_out)) 35 | m.weight.data.uniform_(-w_bound, w_bound) 36 | m.bias.data.fill_(0) 37 | 38 | class Policy(nn.Module): 39 | def __init__(self): 40 | super(Policy, self).__init__() 41 | self.num_envs = num_envs 42 | 43 | self.mu_heads = nn.ModuleList ( [ nn.Sequential( 44 | nn.Linear(12, 128), 45 | nn.ReLU(), 46 | nn.Linear(128, 64), 47 | nn.ReLU(), 48 | nn.Linear(64, 32), 49 | nn.ReLU(), 50 | nn.Linear(32, 16), 51 | nn.ReLU(), 52 | nn.Linear(16,4) 53 | ) for i in range(self.num_envs+1) ] ) 54 | self.sigma2_heads = nn.ModuleList ( [ nn.Sequential( 55 | nn.Linear(12, 128), 56 | nn.ReLU(), 57 | nn.Linear(128, 64), 58 | nn.ReLU(), 59 | nn.Linear(64, 32), 60 | nn.ReLU(), 61 | nn.Linear(32, 16), 62 | nn.ReLU(), 63 | nn.Linear(16,4) 64 | ) for i in range(self.num_envs+1) ] ) 65 | self.value_heads = nn.ModuleList([nn.Linear(16, 4) for i in range(self.num_envs)]) 66 | self.apply(weights_init) 67 | self.div = [[] for i in range(num_envs)] 68 | self.saved_actions = [[] for i in range(self.num_envs)] 69 | #self.entropies = [[] for i in range(num_envs)] 70 | self.entropies = [] 71 | self.rewards = [[] for i in range(self.num_envs)] 72 | 73 | def forward(self, y, index): 74 | '''updated to have 5 return values (2 for each action head one for 75 | value''' 76 | x = y 77 | mu = F.softmax(self.mu_heads[index](x),dim=-1)[0] 78 | sigma2 = self.sigma2_heads[index](x) 79 | sigma = F.softplus(sigma2) 80 | value = self.value_heads[index](x) 81 | mu_dist = F.softmax(self.mu_heads[-1](x),dim=-1)[0] 82 | sigma2_dist = self.sigma2_heads[-1](x) 83 | sigma_dist = F.softplus(sigma2_dist) 84 | return mu, sigma, value, mu_dist, sigma_dist 85 | 86 | def select_action(state, index): 87 | '''given a state, this function chooses the action to take 88 | arguments: state - observation matrix specifying the current model state 89 | env - integer specifying which environment to sample action 90 | for 91 | return - action to take''' 92 | 93 | state = Variable(state) 94 | mu, sigma, value, mu_t, sigma_t = model(state, index) 95 | 96 | prob = Normal(args.alpha*mu_t + args.beta*mu, args.alpha*sigma_t.sqrt() + \ 97 | args.beta*sigma.sqrt()) 98 | 99 | entropy = 0.5*(((args.alpha*sigma_t + args.beta*sigma)*2*pi).log() + 1) 100 | 101 | 102 | new_KL = torch.div(sigma_t.sqrt(),args.alpha*sigma_t.sqrt() + \ 103 | args.beta*sigma.sqrt()).log() + \ 104 | torch.div((args.alpha*sigma_t.sqrt() + \ 105 | args.beta*sigma.sqrt()).pow(2) + \ 106 | ((args.alpha-1)*mu_t+(args.beta)*mu).pow(2),(2*sigma_t)) - 0.5 107 | 108 | log_prob = prob.loc.log() 109 | model.saved_actions[index].append(SavedAction(log_prob, value)) 110 | model.entropies.append(entropy) 111 | model.div[index].append(new_KL) 112 | 113 | return prob.loc 114 | 115 | def finish_episode(): 116 | policy_losses = [] 117 | value_losses = [] 118 | entropy_sum = 0 119 | loss = torch.zeros(1, 1) 120 | loss = Variable(loss) 121 | 122 | for index in range(num_envs): 123 | saved_actions = model.saved_actions[index] 124 | model_rewards = model.rewards[index] 125 | R = torch.zeros(1, 1) 126 
| R = Variable(R) 127 | rewards = [] 128 | # compute the reward for each state in the end of a rollout 129 | for r in model_rewards[::-1]: 130 | R = r + args.gamma * R 131 | rewards.insert(0, R) 132 | rewards = torch.tensor(rewards) 133 | if rewards.std() != rewards.std() or len(rewards) == 0: 134 | rewards = rewards - rewards.mean() 135 | else: 136 | rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-3) 137 | 138 | for i, reward in enumerate(rewards): 139 | rewards = rewards + args.gamma**i * model.div[index][i].mean() 140 | 141 | for (log_prob, value), r in zip(saved_actions, rewards): 142 | # reward is the delta param 143 | value += Variable(torch.randn(value.size())) 144 | reward = r - value[0].dot(log_prob[0].exp()).item() 145 | # theta 146 | # need gradient descent - so negative 147 | policy_losses.append(-log_prob * reward) 148 | # https://pytorch.org/docs/master/nn.html#torch.nn.SmoothL1Loss 149 | # feeds a weird difference between value and the reward 150 | value_losses.append(F.smooth_l1_loss(value, torch.tensor([r]))) 151 | 152 | loss = (torch.stack(policy_losses).sum() + \ 153 | 0.5*torch.stack(value_losses).sum() - \ 154 | torch.stack(model.entropies).sum() * 0.0001) / num_envs 155 | 156 | # Debugging 157 | if False: 158 | print(divergence[0].data) 159 | print(loss, 'loss') 160 | print() 161 | # compute gradients 162 | optimizer.zero_grad() 163 | loss.backward() 164 | 165 | nn.utils.clip_grad_norm_(model.parameters(), 30) 166 | 167 | # Debugging 168 | if False: 169 | print('grad') 170 | for i in range(num_envs): 171 | print(i) 172 | print(model.mu_heads[i].weight) 173 | 174 | # train the NN 175 | optimizer.step() 176 | 177 | model.div = [[] for i in range(num_envs)] 178 | model.saved_actions = [[] for i in range(num_envs)] 179 | model.entropies = [] 180 | model.rewards = [[] for i in range(model.num_envs)] 181 | 182 | ml10 = metaworld.MT10() # Construct the benchmark, sampling tasks 183 | envs = [] 184 | for name, env_cls in ml10.train_classes.items(): 185 | env = env_cls() 186 | task = random.choice([task for task in ml10.train_tasks 187 | if task.env_name == name]) 188 | env.set_task(task) 189 | envs.append(env) 190 | 191 | pi = Variable(torch.FloatTensor([math.pi])) 192 | eps = np.finfo(np.float32).eps.item() 193 | SavedAction = namedtuple('SavedAction', ['log_prob', 'value']) 194 | test = False 195 | trained = False 196 | args = parser() 197 | model = Policy() 198 | # learning rate - might be useful to change 199 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 200 | 201 | running_reward = 10 202 | run_reward = np.array([10 for i in range(num_envs)]) 203 | roll_length = np.array([0 for i in range(num_envs)]) 204 | trained_envs = np.array([False for i in range(num_envs)]) 205 | rewardsRec = [[] for i in range(num_envs)] 206 | for i_episode in range(6000): 207 | p = np.random.random() 208 | length = 0 209 | for index, env in enumerate(envs): 210 | # Train each environment simultaneously with the distilled policy 211 | state = env.reset() 212 | r = 0 213 | done = False 214 | for t in range(200): # Don't infinite loop while learning 215 | action = select_action(state, index) 216 | state, reward, done, _ = env.step(Categorical(action).sample()) 217 | r += reward 218 | model.rewards[index].append(reward) 219 | if args.render: 220 | env.render() 221 | if done or t == 199: 222 | print(i_episode, index, r) 223 | rewardsRec[index].append(r) 224 | length += t 225 | roll_length[index] = t 226 | break 227 | np.save('distral_rewardsRec.npy',rewardsRec) 228 | 
    torch.save(model.state_dict(), 'distral_params.pkl')
229 |     finish_episode()
230 |
--------------------------------------------------------------------------------
/gradsur.py:
--------------------------------------------------------------------------------
1 | import metaworld
2 | import random
3 | import numpy as np
4 | from operator import mul
5 | from functools import reduce
6 | import torch
7 | import torch.autograd as autograd
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import torch.nn.utils as utils
12 | from torch.autograd import Variable
13 | from torch.distributions.categorical import Categorical
14 | import torch_ac
15 | import pdb
16 | import gym
17 | from itertools import count
18 | from collections import namedtuple
19 | import argparse, math, os
20 | device = torch.device("cuda")
21 | torch.manual_seed(1337)
22 | SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
23 |
24 | def pc_grad_update(gradient_list):
25 |     '''
26 |     PyTorch implementation of PCGrad.
27 |     Gradient Surgery for Multi-Task Learning: https://arxiv.org/pdf/2001.06782.pdf
28 |     Arguments:
29 |         gradient_list (Iterable[Tensor] or Tensor): an iterable of Tensors that will
30 |             have gradients with respect to parameters for each task.
31 |     Returns:
32 |         List of gradients with PCGrad applied.
33 |     '''
34 |
35 |     assert type(gradient_list) is list
36 |     assert len(gradient_list) != 0
37 |     num_tasks = len(gradient_list)
38 |     num_params = len(gradient_list[0])
39 |     np.random.shuffle(gradient_list)
40 |
41 |     def flatten_and_store_dims(grad_task):
42 |         output = []
43 |         grad_dim = []
44 |         for param_grad in grad_task:  # TODO(speedup): convert to map since they are faster
45 |             grad_dim.append(tuple(param_grad.shape))
46 |             output.append(torch.flatten(param_grad))
47 |
48 |         return torch.cat(output), grad_dim
49 |
50 |     def restore_dims(grad_task, chunk_dims):
51 |         ## chunk_dims is a list of tensor shapes
52 |         chunk_sizes = [reduce(mul, dims, 1) for dims in chunk_dims]
53 |
54 |         grad_chunk = torch.split(grad_task, split_size_or_sections=chunk_sizes)
55 |         resized_chunks = []
56 |         for index, grad in enumerate(grad_chunk):  # TODO(speedup): convert to map since they are faster
57 |             grad = torch.reshape(grad, chunk_dims[index])
58 |             resized_chunks.append(grad)
59 |
60 |         return resized_chunks
61 |
62 |     def project_gradients(grad_task):
63 |         """
64 |         Subtracts projected gradient components for each grad in gradient_list
65 |         if it conflicts with the input gradient (see the standalone sketch after this file).
66 |         Argument:
67 |             grad_task (Tensor): A tensor for a gradient
68 |         Returns:
69 |             Component subtracted gradient
70 |         """
71 |         grad_task, grad_dim = flatten_and_store_dims(grad_task)
72 |
73 |         for k in range(num_tasks):  # TODO(speedup): convert to map since they are faster
74 |             conflict_gradient_candidate = gradient_list[k]
75 |             # no need to store dims of candidate since we are not changing it in the array
76 |             conflict_gradient_candidate, _ = flatten_and_store_dims(conflict_gradient_candidate)
77 |
78 |             inner_product = torch.dot(torch.flatten(grad_task), torch.flatten(conflict_gradient_candidate))
79 |             # TODO(speedup): put conflict check condition here so that we aren't calculating norms for non-conflicting gradients
80 |             proj_direction = inner_product / torch.norm(conflict_gradient_candidate)**2
81 |
82 |             ## sanity check to see if there's any conflicting gradients
83 |             # if proj_direction < 0.:
84 |             #     print('conflict')
85 |             # TODO(speedup): This is a cumulative subtraction, move to threaded in-memory map-reduce
86 |             grad_task = grad_task - min(proj_direction, 0.) * conflict_gradient_candidate
87 |
88 |         # get back grad_task
89 |         grad_task = restore_dims(grad_task, grad_dim)
90 |         return grad_task
91 |
92 |     flattened_grad_task = list(map(project_gradients, gradient_list))
93 |
94 |     return flattened_grad_task
95 |
96 | class ACModel(nn.Module):
97 |
98 |     def __init__(self, tasks=1):
99 |
100 |         super(ACModel, self).__init__()
101 |
102 |         self.actor = torch.nn.ModuleList([nn.Sequential(
103 |             nn.Linear(12, 128),
104 |             nn.ReLU(),
105 |             nn.Linear(128, 64),
106 |             nn.ReLU(),
107 |             nn.Linear(64, 32),
108 |             nn.ReLU(),
109 |             nn.Linear(32, 16),
110 |             nn.ReLU(),
111 |             nn.Linear(16, 4)
112 |         ) for i in range(1)])
113 |         self.value_head = torch.nn.ModuleList([nn.Sequential(
114 |             nn.Linear(12, 128),
115 |             nn.ReLU(),
116 |             nn.Linear(128, 64),
117 |             nn.ReLU(),
118 |             nn.Linear(64, 32),
119 |             nn.ReLU(),
120 |             nn.Linear(32, 16),
121 |             nn.ReLU(),
122 |             nn.Linear(16, 1)
123 |         ) for i in range(1)])
124 |
125 |         self.saved_actions = [[] for i in range(1)]
126 |         self.rewards = [[] for i in range(1)]
127 |         self.tasks = 1
128 |
129 |     def forward(self, x):
130 |         tmp = []
131 |         for i in range(1):
132 |             tmp.append(F.softmax(self.actor[i](x) - self.actor[i](x).max(), dim=-1))
133 |         state_values = self.value_head[0](x)
134 |         return tmp, state_values
135 |
136 | class REINFORCE:
137 |     def __init__(self):
138 |         self.model = ACModel()
139 |         self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
140 |         self.model.train()
141 |         self.saved_actions = [[]]
142 |         self.rewards = [[]]
143 |
144 |     def select_action(self, state):
145 |         state = torch.tensor(list(state)).float()
146 |         probs, state_value = self.model(Variable(state))
147 |
148 |         # Obtain the most probable action for each one of the policies
149 |         actions = []
150 |         self.saved_actions[0].append(SavedAction(probs[0].log().dot(probs[0]), state_value))
151 |
152 |         return probs, state_value
153 |
154 |     def update_parameters(self, rewards, log_probs, entropies, gamma, multitask=False):  # update the policy parameters
155 |         R = torch.zeros(1, 1)
156 |         loss = 0
157 |         for i in reversed(range(len(rewards))):
158 |             R = gamma * R + rewards[i]  # accumulate the discounted return backwards
159 |             # loss = loss - (log_probs[i]*(Variable(R).expand_as(log_probs[i])).cuda()).sum() - (0.0001*entropies[i].cuda()).sum()
160 |             loss = loss - (log_probs[i]*(Variable(R).expand_as(log_probs[i]))).sum() - (0.0001*entropies[i]).sum()
161 |         loss = loss / len(rewards)
162 |         if multitask:
163 |             loss += lossw2(currindex, env_id, loss, w, B)  # lossw2, currindex, w and B are defined externally; this branch is not exercised in this script
164 |         self.optimizer.zero_grad()
165 |         loss.backward()
166 |         utils.clip_grad_norm_(self.model.parameters(), 40)  # gradient clipping: max L2 norm of the gradients = 40
167 |         self.optimizer.step()
168 |
169 | def finish_episode(tasks, alpha, beta, gamma):
170 |     optimizer = [optim.Adam(agents[i].model.parameters(), lr=3e-2) for i in range(len(envs))]
171 |     losses = []
172 |     grad_list = []
173 |     for env_id in range(len(envs)):
174 |         ### Calculate loss function according to Equation 1
175 |         R = 0
176 |         saved_actions = agents[env_id].saved_actions[0]
177 |         policy_losses = []
178 |         value_losses = []
179 |
180 |         ## Obtain the discounted rewards backwards
181 |         rewards = []
182 |         for r in agents[env_id].rewards[0][::-1]:
183 |             R = r + gamma * R
184 |             rewards.insert(0, R)
185 |
186 |         ## Standardize the rewards to be unit normal (to control the gradient estimator variance)
187 |         rewards = torch.Tensor(rewards)
188 |         rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
189 |         for (log_prob, value), r in zip(saved_actions, rewards):
190 |             reward = r - value.data[0]
191 |             policy_losses.append(-log_prob * reward)
192 |             value_losses.append(F.smooth_l1_loss(value, Variable(torch.Tensor([r]))))
193 |         optimizer[env_id].zero_grad()
194 |         loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()
195 |         loss.backward()
196 |         optimizer[env_id].step()
197 |
198 |         for i in range(1):
199 |             del agents[env_id].rewards[i][:]
200 |             del agents[env_id].saved_actions[0][:]
201 |         if rnd != 0:
202 |             tmp = []
203 |             for p in agents[env_id].model.parameters():
204 |                 # collect this task's parameter gradients for PCGrad
205 |                 tmp.append(p.grad)
206 |             grad_list.append(tmp)
207 |
208 |     if rnd != 0:  # grad_list is only populated after the first round
209 |         pc_grad_update(grad_list)
210 |
211 | ml10 = metaworld.MT10()  # Construct the benchmark, sampling tasks
212 | envs = []
213 | for name, env_cls in ml10.train_classes.items():
214 |     env = env_cls()
215 |     task = random.choice([task for task in ml10.train_tasks
216 |                           if task.env_name == name])
217 |     env.set_task(task)
218 |     envs.append(env)
219 |
220 | for env in envs:
221 |     obs = env.reset()  # Reset environment
222 |     a = env.action_space.sample()  # Sample an action
223 |     obs, reward, done, info = env.step(a)  # Step the environment with the sampled random action
224 |
225 | rnd = 0
226 | batch_size = 128
227 | alpha = 0.5
228 | beta = 0.5
229 | gamma = 0.999
230 | is_plot = False
231 | num_episodes = 500
232 | max_num_steps_per_episode = 10000
233 | learning_rate = 0.001
234 | rewardsRec = [[] for _ in range(len(envs))]
235 | tasks = len(envs)
236 | agents = [REINFORCE() for i in range(len(envs))]
237 | optimizer = [optim.Adam(agents[i].model.parameters(), lr=3e-2) for i in range(len(envs))]
238 |
239 | for rnd in range(10000):
240 |     for env_id in range(len(envs)):
241 |         rewardRec = []
242 |         for i_episode in range(1):
243 |             visualise = True
244 |             rewardcnt = 0
245 |             observations = envs[env_id].reset()
246 |             for t in range(200):
247 |                 probs, state_value = agents[env_id].select_action(observations)
248 |                 observations, reward, done, _ = envs[env_id].step(probs[0].detach().numpy())
249 |                 agents[env_id].rewards[0].append(reward)
250 |                 rewardcnt += reward
251 |                 if done:
252 |                     break
253 |             rewardRec.append(rewardcnt)
254 |             rewardsRec[env_id].append(rewardcnt)
255 |             np.save('rewardsRec2_gradsur_meta.npy', rewardsRec)
256 |             print(rnd, env_id, rewardcnt)
257 |
258 |     finish_episode(tasks, alpha, beta, gamma)
259 |     for env_id in range(len(envs)):
260 |         torch.save(agents[env_id].model.state_dict(), str(env_id)+'meta10.pkl')  # save only the network parameters (fast, low memory)
261 |
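The projection inside project_gradients above is easier to follow without the flatten/restore bookkeeping. Below is a minimal, self-contained sketch of the same PCGrad step on two toy task gradients (the tensors and names here are illustrative only, not part of the repository): when two task gradients conflict (negative inner product), the component of one along the other is removed; non-conflicting pairs pass through unchanged, which is what the min(proj_direction, 0.) term above implements.

import torch

def flatten(grads):
    return torch.cat([g.flatten() for g in grads])

# Toy per-parameter gradients for two tasks.
g_task1 = [torch.tensor([1.0, 2.0]), torch.tensor([[0.5]])]
g_task2 = [torch.tensor([-2.0, 1.0]), torch.tensor([[-1.0]])]

v1, v2 = flatten(g_task1), flatten(g_task2)

inner = torch.dot(v1, v2)                  # negative => the two tasks conflict
if inner < 0:
    v1 = v1 - inner / v2.norm() ** 2 * v2  # subtract the conflicting component of v1

print(torch.dot(v1, v2))                   # ~0 after projection: v1 no longer opposes v2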
-------------------------------------------------------------------------------- /SAC.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import os 3 | import sys 4 | import yaml 5 | import argparse 6 | from datetime import datetime 7 | from abc import ABC, abstractmethod 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import constopt 11 | from constopt.constraints import LinfBall 12 | from constopt.stochastic import PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe 13 | import torch 14 | from torch.optim import Adam 15 | from torch.autograd import Variable 16 | import torch.nn as nn 17 | from torch.nn import functional as F 18 | from torch.distributions import Categorical 19 | import torch.nn.utils as utils 20 | from torch.utils.tensorboard import SummaryWriter 21 | from scipy.stats import rankdata 22 | from collections import deque 23 | sys.path.insert(0, r'constopt-pytorch/') 24 | 25 | class BaseAgent(ABC): 26 | def __init__(self, env, test_env, log_dir, num_steps=100000, batch_size=16, 27 | memory_size=1000000, gamma=0.99, multi_step=1, 28 | target_entropy_ratio=0.98, start_steps=20000, 29 | update_interval=4, target_update_interval=8000, 30 | use_per=False, num_eval_steps=1000, max_episode_steps=27000, 31 | log_interval=10, eval_interval=1000, cuda=True, seed=0): 32 | 33 | super().__init__() 34 | self.env = env 35 | self.test_env = test_env 36 | 37 | # Set seed. 38 | torch.manual_seed(seed) 39 | np.random.seed(seed) 40 | self.env.seed(seed) 41 | self.test_env.seed(2**31-1-seed) 42 | # torch.backends.cudnn.deterministic = True # It harms a performance. 43 | # torch.backends.cudnn.benchmark = False # It harms a performance. 44 | 45 | self.device = 'cpu' 46 | # LazyMemory efficiently stores FrameStacked states. 
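        # Replay buffer selection: with use_per=True the buffer would keep per-transition
        # priorities (prioritized experience replay), and beta_steps below spreads the
        # importance-sampling correction over the updates expected after the warm-up phase.
        # LazyPrioritizedMultiStepMemory is not defined in this file, so only the
        # use_per=False branch (LazyMultiStepMemory, defined further down) runs as-is.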
47 | if use_per: 48 | beta_steps = (num_steps - start_steps) / update_interval 49 | self.memory = LazyPrioritizedMultiStepMemory( 50 | capacity=memory_size, 51 | state_shape=self.env.observation_space.shape, 52 | device=self.device, gamma=gamma, multi_step=multi_step, 53 | beta_steps=beta_steps) 54 | else: 55 | self.memory = LazyMultiStepMemory( 56 | capacity=memory_size, 57 | state_shape=self.env.observation_space.shape, 58 | device=self.device, gamma=gamma, multi_step=multi_step) 59 | 60 | self.log_dir = log_dir 61 | self.model_dir = os.path.join(log_dir, 'model') 62 | self.summary_dir = os.path.join(log_dir, 'summary') 63 | if not os.path.exists(self.model_dir): 64 | os.makedirs(self.model_dir) 65 | if not os.path.exists(self.summary_dir): 66 | os.makedirs(self.summary_dir) 67 | 68 | self.writer = SummaryWriter(log_dir=self.summary_dir) 69 | self.train_return = RunningMeanStats(log_interval) 70 | 71 | self.steps = 0 72 | self.learning_steps = 0 73 | self.episodes = 0 74 | self.best_eval_score = -np.inf 75 | self.num_steps = num_steps 76 | self.batch_size = 16 77 | self.gamma_n = gamma ** multi_step 78 | self.start_steps = start_steps 79 | self.update_interval = update_interval 80 | self.target_update_interval = target_update_interval 81 | self.use_per = use_per 82 | self.num_eval_steps = num_eval_steps 83 | self.max_episode_steps =200 84 | self.log_interval = log_interval 85 | self.eval_interval = eval_interval 86 | 87 | def run(self, rnd): 88 | self.train_episode(rnd) 89 | 90 | def is_update(self): 91 | return self.steps % self.update_interval == 0 and self.steps >= self.start_steps 92 | 93 | @abstractmethod 94 | def explore(self, state): 95 | pass 96 | 97 | @abstractmethod 98 | def exploit(self, state): 99 | pass 100 | 101 | @abstractmethod 102 | def update_target(self): 103 | pass 104 | 105 | @abstractmethod 106 | def calc_current_q(self, states, actions, rewards, next_states, dones): 107 | pass 108 | 109 | @abstractmethod 110 | def calc_target_q(self, states, actions, rewards, next_states, dones): 111 | pass 112 | 113 | @abstractmethod 114 | def calc_critic_loss(self, batch, weights): 115 | pass 116 | 117 | @abstractmethod 118 | def calc_policy_loss(self, batch, weights): 119 | pass 120 | 121 | @abstractmethod 122 | def calc_entropy_loss(self, entropies, weights): 123 | pass 124 | 125 | def train_episode(self, rnd): 126 | if rnd > 0: 127 | self.policy.load_state_dict(torch.load(str(index)+'policysac_newatari.pth')) 128 | self.online_critic.load_state_dict(torch.load(str(index)+'online_criticsac_newatari.pth')) 129 | self.target_critic.load_state_dict(torch.load(str(index)+ 'target_criticsac_newatari.pth')) 130 | for inner_rnd in range(20): 131 | self.episodes += 1 132 | episode_return = 0. 133 | episode_steps = 0 134 | done = False 135 | state = self.env.reset() 136 | 137 | while (not done) and episode_steps <= 200-1: 138 | if self.start_steps > self.steps: 139 | action = self.env.action_space.sample() 140 | next_state, reward, done, info = self.env.step(action) 141 | else: 142 | action = self.explore(state) 143 | next_state, reward, done, info = self.env.step(action.detach().cpu().numpy()[0][0]) 144 | # Clip reward to [-1.0, 1.0]. 145 | clipped_reward = max(min(reward, 1.0), -1.0) 146 | 147 | # To calculate efficiently, set priority=max_priority here. 
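                # The transition stored for learning uses clipped_reward, while the raw reward
                # keeps accumulating into episode_return below, so logged returns stay on the
                # environment's original scale.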
148 | self.memory.append(state, action, clipped_reward, next_state, done) 149 | 150 | self.steps += 1 151 | episode_steps += 1 152 | episode_return += reward 153 | state = next_state 154 | 155 | if (inner_rnd+1)%5 == 0: 156 | print('save') 157 | self.learn(inner_rnd) 158 | self.update_target() 159 | self.save_models(os.path.join(self.model_dir, 'final')) 160 | rewardsRec[index].append(episode_return) 161 | np.save('sac_newatari_rewardsRec.npy', rewardsRec) 162 | np.save('sac_newatari_succeessRec.npy', succeessRec) 163 | # We log running mean of training rewards. 164 | self.train_return.append(episode_return) 165 | 166 | if self.episodes % self.log_interval == 0: 167 | self.writer.add_scalar( 168 | 'reward/train', self.train_return.get(), self.steps) 169 | 170 | print('Env: ', index, 'rnd: ', rnd, 'Episode: ', self.episodes, 'Return: ', episode_return) 171 | 172 | def learn(self, inner_rnd): 173 | assert hasattr(self, 'q1_optim') and hasattr(self, 'q2_optim') and hasattr(self, 'policy_optim') and hasattr(self, 'alpha_optim') 174 | 175 | self.learning_steps += 1 176 | 177 | if self.use_per: 178 | batch, weights = self.memory.sample(self.batch_size) 179 | else: 180 | batch = self.memory.sample(self.batch_size) 181 | # Set priority weights to 1 when we don't use PER. 182 | weights = 1. 183 | 184 | q1_loss, q2_loss, errors, mean_q1, mean_q2 = self.calc_critic_loss(batch, weights) 185 | policy_loss, entropies = self.calc_policy_loss(batch, weights) 186 | entropy_loss = self.calc_entropy_loss(entropies, weights) 187 | 188 | update_params(self.q1_optim, q1_loss, inner_rnd, True) 189 | update_params(self.q2_optim, q2_loss, inner_rnd) 190 | update_params(self.policy_optim, policy_loss, inner_rnd) 191 | update_params(self.alpha_optim, entropy_loss, inner_rnd) 192 | 193 | self.alpha = self.log_alpha.exp() 194 | 195 | if self.use_per: 196 | self.memory.update_priority(errors) 197 | 198 | if self.learning_steps % self.log_interval == 0: 199 | self.writer.add_scalar( 200 | 'loss/Q1', q1_loss.detach().item(), 201 | self.learning_steps) 202 | self.writer.add_scalar( 203 | 'loss/Q2', q2_loss.detach().item(), 204 | self.learning_steps) 205 | self.writer.add_scalar( 206 | 'loss/policy', policy_loss.detach().item(), 207 | self.learning_steps) 208 | self.writer.add_scalar( 209 | 'loss/alpha', entropy_loss.detach().item(), 210 | self.learning_steps) 211 | self.writer.add_scalar( 212 | 'stats/alpha', self.alpha.detach().item(), 213 | self.learning_steps) 214 | self.writer.add_scalar( 215 | 'stats/mean_Q1', mean_q1, self.learning_steps) 216 | self.writer.add_scalar( 217 | 'stats/mean_Q2', mean_q2, self.learning_steps) 218 | self.writer.add_scalar( 219 | 'stats/entropy', entropies.detach().mean().item(), 220 | self.learning_steps) 221 | 222 | def evaluate(self): 223 | num_episodes = 0 224 | num_steps = 0 225 | total_return = 0.0 226 | 227 | while True: 228 | state = self.test_env.reset() 229 | episode_steps = 0 230 | episode_return = 0.0 231 | done = False 232 | while (not done) and episode_steps <=200-1: 233 | action = self.exploit(state) 234 | next_state, reward, done, _ = self.test_env.step(action.view(4).detach().numpy()) 235 | num_steps += 1 236 | episode_steps += 1 237 | episode_return += reward 238 | state = next_state 239 | 240 | num_episodes += 1 241 | total_return += episode_return 242 | 243 | if num_steps > self.num_eval_steps: 244 | break 245 | 246 | mean_return = total_return / num_episodes 247 | 248 | if mean_return > self.best_eval_score: 249 | self.best_eval_score = mean_return 250 | 
self.save_models(os.path.join(self.model_dir, 'best')) 251 | 252 | self.writer.add_scalar( 253 | 'reward/test', mean_return, self.steps) 254 | print('-' * 60) 255 | print(f'Num steps: {self.steps:<5} ' 256 | f'return: {mean_return:<5.1f}') 257 | print('-' * 60) 258 | 259 | @abstractmethod 260 | def save_models(self, save_dir): 261 | if not os.path.exists(save_dir): 262 | os.makedirs(save_dir) 263 | 264 | def __del__(self): 265 | self.env.close() 266 | self.test_env.close() 267 | self.writer.close() 268 | 269 | def initialize_weights_he(m): 270 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 271 | torch.nn.init.kaiming_uniform_(m.weight) 272 | if m.bias is not None: 273 | torch.nn.init.constant_(m.bias, 0) 274 | 275 | class Flatten(nn.Module): 276 | def forward(self, x): 277 | return x.view(x.size(0), -1) 278 | 279 | class BaseNetwork(nn.Module): 280 | def save(self, path): 281 | torch.save(self.state_dict(), path) 282 | 283 | def load(self, path): 284 | self.load_state_dict(torch.load(path)) 285 | 286 | class QNetwork(BaseNetwork): 287 | 288 | def __init__(self, num_channels, num_actions, shared=False, 289 | dueling_net=False): 290 | super().__init__() 291 | self.image_conv =nn.Sequential(nn.Conv2d(3, 1, (5, 5), (5, 5)), nn.Conv2d(1, 1, (3, 3), (3, 3))) 292 | 293 | if not dueling_net: 294 | self.head = nn.Sequential( 295 | nn.Linear(140, 128), 296 | nn.ReLU(), 297 | nn.Linear(128, 64), 298 | nn.ReLU(), 299 | nn.Linear(64, 32), 300 | nn.ReLU(), 301 | nn.Linear(32, 16), 302 | nn.ReLU(), 303 | nn.Linear(16, envs[0].action_space.n)) 304 | else: 305 | self.a_head = nn.Sequential( 306 | nn.Linear(140, 128), 307 | nn.ReLU(), 308 | nn.Linear(128, 64), 309 | nn.ReLU(), 310 | nn.Linear(64, 32), 311 | nn.ReLU(), 312 | nn.Linear(32, 16), 313 | nn.ReLU(), 314 | nn.Linear(16, envs[0].action_space.n)) 315 | self.v_head = nn.Sequential( 316 | nn.Linear(140, 128), 317 | nn.ReLU(), 318 | nn.Linear(128, 64), 319 | nn.ReLU(), 320 | nn.Linear(64, 32), 321 | nn.ReLU(), 322 | nn.Linear(32, 16), 323 | nn.ReLU(), 324 | nn.Linear(16, 1)) 325 | 326 | self.shared = shared 327 | self.dueling_net = dueling_net 328 | 329 | def forward(self, states): 330 | if len(states.shape) == 3: 331 | states = states.view([1] + list(states.shape)) 332 | else: 333 | states = states.permute(0, 3, 1, 2) 334 | new_x = states 335 | new_x = self.image_conv(new_x) 336 | states = new_x.reshape(states.shape[0], -1) 337 | if not self.dueling_net: 338 | return self.head(states) 339 | else: 340 | a = self.a_head(states) 341 | v = self.v_head(states) 342 | return v + a - a.mean(1, keepdim=True) 343 | 344 | class TwinnedQNetwork(BaseNetwork): 345 | def __init__(self, num_channels, num_actions, shared=False, 346 | dueling_net=False): 347 | super().__init__() 348 | self.Q1 = QNetwork(num_channels, envs[0].action_space.n, shared, dueling_net) 349 | self.Q2 = QNetwork(num_channels, envs[0].action_space.n, shared, dueling_net) 350 | 351 | def forward(self, states): 352 | q1 = self.Q1(states) 353 | q2 = self.Q2(states) 354 | return q1, q2 355 | 356 | class CateoricalPolicy(BaseNetwork): 357 | def __init__(self, num_channels, num_actions, shared=False): 358 | super().__init__() 359 | self.image_conv =nn.Sequential( 360 | nn.Conv2d(3, 1, (5, 5), (5, 5)), 361 | nn.Conv2d(1, 1, (3, 3), (3, 3)) 362 | ) 363 | 364 | self.head = nn.Sequential( 365 | nn.Linear(140, 128), 366 | nn.ReLU(), 367 | nn.Linear(128, 64), 368 | nn.ReLU(), 369 | nn.Linear(64, 32), 370 | nn.ReLU(), 371 | nn.Linear(32, 16), 372 | nn.ReLU(), 373 | nn.Linear(16, 
envs[0].action_space.n)) 374 | 375 | self.shared = shared 376 | 377 | def act(self, states): 378 | if len(states.shape) == 3: 379 | states=states.view([1]+list(states.shape)) 380 | else: 381 | states=states.permute(0, 3, 1, 2) 382 | new_x = states 383 | new_x = self.image_conv(new_x) 384 | states = new_x.reshape(states.shape[0], -1) 385 | action_logits = self.head(states) 386 | return action_logits 387 | 388 | def sample(self, states): 389 | if len(states.shape) == 3: 390 | states=states.view([1] + list(states.shape)) 391 | else: 392 | states=states.permute(0, 3, 1, 2) 393 | new_x = states 394 | new_x = self.image_conv(new_x) 395 | states = new_x.reshape(states.shape[0], -1) 396 | action_probs = F.softmax(self.head(states), dim=1) 397 | action_dist = Categorical(action_probs) 398 | actions = action_dist.sample().view(-1, 1) 399 | 400 | # Avoid numerical instability. 401 | z = (action_probs == 0.0).float() * 1e-8 402 | log_action_probs = torch.log(action_probs + z) 403 | 404 | return actions, action_probs, log_action_probs 405 | 406 | class SharedSacdAgent(BaseAgent): 407 | def __init__(self, env, test_env, log_dir, num_steps=100000, batch_size=16, 408 | lr=0.0003, memory_size=1000000, gamma=0.99, multi_step=1, 409 | target_entropy_ratio=0.98, start_steps=20000, 410 | update_interval=4, target_update_interval=1000, 411 | use_per=False, dueling_net=False, num_eval_steps=1000, 412 | max_episode_steps=27000, log_interval=10, eval_interval=1000, 413 | cuda=True, seed=0): 414 | super().__init__( 415 | env, test_env, log_dir, num_steps, batch_size, memory_size, gamma, 416 | multi_step, target_entropy_ratio, start_steps, update_interval, 417 | target_update_interval, use_per, num_eval_steps, max_episode_steps, 418 | log_interval, eval_interval, cuda, seed) 419 | 420 | # Define networks. 421 | 422 | self.policy = CateoricalPolicy( 423 | 12, 4, 424 | shared=True).to(self.device) 425 | self.online_critic = TwinnedQNetwork( 426 | 12, 4, 427 | dueling_net=dueling_net, shared=True).to(device=self.device) 428 | self.target_critic = TwinnedQNetwork( 429 | 12, 4, 430 | dueling_net=dueling_net, shared=True).to(device=self.device).eval() 431 | 432 | # Copy parameters of the learning network to the target network. 433 | self.target_critic.load_state_dict(self.online_critic.state_dict()) 434 | 435 | # Disable gradient calculations of the target network. 436 | disable_gradients(self.target_critic) 437 | 438 | self.policy_optim = Adam(self.policy.parameters(), lr=lr) 439 | self.q1_optim = Adam( 440 | list(self.online_critic.Q1.parameters()), lr=lr) 441 | self.q2_optim = Adam(self.online_critic.Q2.parameters(), lr=lr) 442 | 443 | # Target entropy is -log(1/|A|) * ratio (= maximum entropy * ratio). 444 | self.target_entropy = -np.log(1.0 / 4) * target_entropy_ratio 445 | 446 | # We optimize log(alpha), instead of alpha. 447 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device) 448 | self.alpha = self.log_alpha.exp() 449 | self.alpha_optim = Adam([self.log_alpha], lr=lr) 450 | 451 | def explore(self, state): 452 | # Act with randomness. 453 | state = torch.ByteTensor( 454 | state[None, ...]).to(self.device).float() / 255. 455 | with torch.no_grad(): 456 | action, _, _ = self.policy.sample(state) 457 | return action 458 | 459 | def exploit(self, state): 460 | # Act without randomness. 461 | state = torch.ByteTensor( 462 | state[None, ...]).to(self.device).float() / 255. 
463 | with torch.no_grad(): 464 | action = self.policy.act(state) 465 | return action 466 | 467 | def update_target(self): 468 | self.target_critic.load_state_dict(self.online_critic.state_dict()) 469 | 470 | def calc_current_q(self, states, actions, rewards, next_states, dones): 471 | curr_q1 = self.online_critic.Q1(states).gather(1, actions.long()) 472 | curr_q2 = self.online_critic.Q2( 473 | states.detach()).gather(1, actions.long()) 474 | return curr_q1, curr_q2 475 | 476 | def calc_target_q(self, states, actions, rewards, next_states, dones): 477 | with torch.no_grad(): 478 | _, action_probs, log_action_probs = self.policy.sample(next_states) 479 | next_q1, next_q2 = self.target_critic(next_states) 480 | next_q = (action_probs * ( 481 | torch.min(next_q1, next_q2) - self.alpha * log_action_probs 482 | )).sum(dim=1, keepdim=True) 483 | 484 | assert rewards.shape == next_q.shape 485 | return rewards + (1.0 - dones) * self.gamma_n * next_q 486 | 487 | def calc_critic_loss(self, batch, weights): 488 | curr_q1, curr_q2 = self.calc_current_q(*batch) 489 | target_q = self.calc_target_q(*batch) 490 | 491 | # TD errors for updating priority weights 492 | errors = torch.abs(curr_q1.detach() - target_q) 493 | 494 | # We log means of Q to monitor training. 495 | mean_q1 = curr_q1.detach().mean().item() 496 | mean_q2 = curr_q2.detach().mean().item() 497 | 498 | # Critic loss is mean squared TD errors with priority weights. 499 | q1_loss = torch.mean((curr_q1 - target_q).pow(2) * weights) 500 | q2_loss = torch.mean((curr_q2 - target_q).pow(2) * weights) 501 | 502 | return q1_loss, q2_loss, errors, mean_q1, mean_q2 503 | 504 | def calc_policy_loss(self, batch, weights): 505 | states, actions, rewards, next_states, dones = batch 506 | 507 | # (Log of) probabilities to calculate expectations of Q and entropies. 508 | _, action_probs, log_action_probs = self.policy.sample(states) 509 | with torch.no_grad(): 510 | # Q for every actions to calculate expectations of Q. 511 | q1, q2 = self.online_critic(states) 512 | q = torch.min(q1, q2) 513 | 514 | # Expectations of entropies. 515 | entropies = -torch.sum( 516 | action_probs * log_action_probs, dim=1, keepdim=True) 517 | 518 | # Expectations of Q. 519 | q = torch.sum(torch.min(q1, q2) * action_probs, dim=1, keepdim=True) 520 | 521 | # Policy objective is maximization of (Q + alpha * entropy) with 522 | # priority weights. 523 | policy_loss = (weights * (- q - self.alpha * entropies)).mean() 524 | 525 | return policy_loss, entropies.detach() 526 | 527 | def calc_entropy_loss(self, entropies, weights): 528 | assert not entropies.requires_grad 529 | 530 | # Intuitively, we increse alpha when entropy is less than target 531 | # entropy, vice versa. 
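        # With gradient descent on log_alpha, d(entropy_loss)/d(log_alpha) is proportional to
        # -(target_entropy - H): alpha grows while the policy entropy H is below target_entropy
        # and shrinks once it rises above it.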
532 | entropy_loss = -torch.mean( 533 | self.log_alpha * (self.target_entropy - entropies) 534 | * weights) 535 | return entropy_loss 536 | 537 | def save_models(self, save_dir): 538 | super().save_models(save_dir) 539 | save_dir=os.path.join('logs_sac_newatari', str(index), f'{name}-seed{args.seed}-{time}') 540 | self.policy.save( str(index)+'policysac_newatari.pth') 541 | self.online_critic.save(str(index)+'online_criticsac_newatari.pth') 542 | self.target_critic.save(str(index)+'target_criticsac_newatari.pth') 543 | 544 | class MultiStepBuff: 545 | def __init__(self, maxlen=3): 546 | super(MultiStepBuff, self).__init__() 547 | self.maxlen = int(maxlen) 548 | self.reset() 549 | 550 | def append(self, state, action, reward): 551 | self.states.append(state) 552 | self.actions.append(action) 553 | self.rewards.append(reward) 554 | 555 | def get(self, gamma=0.99): 556 | assert len(self.rewards) > 0 557 | state = self.states.popleft() 558 | action = self.actions.popleft() 559 | reward = self._nstep_return(gamma) 560 | return state, action, reward 561 | 562 | def _nstep_return(self, gamma): 563 | r = np.sum([r * (gamma ** i) for i, r in enumerate(self.rewards)]) 564 | self.rewards.popleft() 565 | return r 566 | 567 | def reset(self): 568 | # Buffer to store n-step transitions. 569 | self.states = deque(maxlen=self.maxlen) 570 | self.actions = deque(maxlen=self.maxlen) 571 | self.rewards = deque(maxlen=self.maxlen) 572 | 573 | def is_empty(self): 574 | return len(self.rewards) == 0 575 | 576 | def is_full(self): 577 | return len(self.rewards) == self.maxlen 578 | 579 | def __len__(self): 580 | return len(self.rewards) 581 | 582 | class LazyMemory(dict): 583 | def __init__(self, capacity, state_shape, device): 584 | super(LazyMemory, self).__init__() 585 | self.capacity = int(capacity) 586 | self.state_shape = state_shape 587 | self.device = device 588 | self.reset() 589 | 590 | def reset(self): 591 | self['state'] = [] 592 | self['next_state'] = [] 593 | 594 | self['action'] = np.empty((self.capacity, 4), dtype=np.int64) 595 | self['reward'] = np.empty((self.capacity, 1), dtype=np.float32) 596 | self['done'] = np.empty((self.capacity, 1), dtype=np.float32) 597 | 598 | self._n = 0 599 | self._p = 0 600 | 601 | def append(self, state, action, reward, next_state, done, 602 | episode_done=None): 603 | self._append(state, action, reward, next_state, done) 604 | 605 | def _append(self, state, action, reward, next_state, done): 606 | self['state'].append(state) 607 | self['next_state'].append(next_state) 608 | self['action'][self._p] = action 609 | self['reward'][self._p] = reward 610 | self['done'][self._p] = done 611 | 612 | self._n = min(self._n + 1, self.capacity) 613 | self._p = (self._p + 1) % self.capacity 614 | 615 | self.truncate() 616 | 617 | def truncate(self): 618 | while len(self['state']) > self.capacity: 619 | del self['state'][0] 620 | del self['next_state'][0] 621 | 622 | def sample(self, batch_size): 623 | indices = np.random.randint(low=0, high=len(self), size=batch_size) 624 | return self._sample(indices, batch_size) 625 | 626 | def _sample(self, indices, batch_size): 627 | bias = -self._p if self._n == self.capacity else 0 628 | 629 | states = np.empty( 630 | (batch_size, *self.state_shape), dtype=np.uint8) 631 | next_states = np.empty( 632 | (batch_size, *self.state_shape), dtype=np.uint8) 633 | 634 | for i, index1 in enumerate(indices): 635 | _index = np.mod(index1+bias, self.capacity) 636 | states[i, ...] = self['state'][_index] 637 | next_states[i, ...] 
= self['next_state'][_index] 638 | 639 | states = torch.ByteTensor(states).to(self.device).float() / 255. 640 | next_states = torch.ByteTensor( 641 | next_states).to(self.device).float() / 255. 642 | actions = torch.LongTensor(self['action'][indices]).to(self.device) 643 | rewards = torch.FloatTensor(self['reward'][indices]).to(self.device) 644 | dones = torch.FloatTensor(self['done'][indices]).to(self.device) 645 | 646 | return states, actions, rewards, next_states, dones 647 | 648 | def __len__(self): 649 | return self._n 650 | 651 | class LazyMultiStepMemory(LazyMemory): 652 | def __init__(self, capacity, state_shape, device, gamma=0.99, 653 | multi_step=3): 654 | super(LazyMultiStepMemory, self).__init__( 655 | capacity, state_shape, device) 656 | 657 | self.gamma = gamma 658 | self.multi_step = int(multi_step) 659 | if self.multi_step != 1: 660 | self.buff = MultiStepBuff(maxlen=self.multi_step) 661 | 662 | def append(self, state, action, reward, next_state, done): 663 | if self.multi_step != 1: 664 | self.buff.append(state, action, reward) 665 | 666 | if self.buff.is_full(): 667 | state, action, reward = self.buff.get(self.gamma) 668 | self._append(state, action, reward, next_state, done) 669 | 670 | if done: 671 | while not self.buff.is_empty(): 672 | state, action, reward = self.buff.get(self.gamma) 673 | self._append(state, action, reward, next_state, done) 674 | else: 675 | self._append(state, action, reward, next_state, done) 676 | 677 | def update_params(optim, loss, inner_rnd, retain_graph=False): 678 | optim.zero_grad() 679 | w=[] 680 | for key in agents[0].policy.state_dict().keys(): 681 | w.append(torch.cat([agents[j].policy.state_dict()[key].unsqueeze(-1) for j in range(len(envs))])) 682 | pre=loss 683 | with torch.autograd.set_detect_anomaly(True): 684 | pre.backward(retain_graph=retain_graph) 685 | # nn.utils.clip_grad_norm_(model.parameters(), 30) 686 | optim.step() 687 | 688 | def disable_gradients(network): 689 | # Disable calculations of gradients. 
690 | for param in network.parameters(): 691 | param.requires_grad = False 692 | 693 | class RunningMeanStats: 694 | def __init__(self, n=10): 695 | self.n = n 696 | self.stats = deque(maxlen=n) 697 | 698 | def append(self, x): 699 | self.stats.append(x) 700 | 701 | def get(self): 702 | return np.mean(self.stats) 703 | 704 | def loss(rloss, w, B, mu=0.2, lamb=[0.01, 0.01, 0.01]): 705 | return torch.tensor([1+mu*(np.linalg.norm(B[t], ord=1)-np.linalg.norm(B[t][t], ord=1)) for t in range(len(envs))]).dot(rloss)+lamb[0]*sum([sum([sum([torch.norm(w[i][t]-sum([B.T[t][j]*w[i][j] for j in range(len(envs))]), p=2)**2]) for i in range(2)]) for t in range(len(envs))]) 706 | 707 | class parser: 708 | def __init__(self): 709 | self.config=os.path.join('metaworld-master/', 'sacd.yaml') 710 | self.shared=True 711 | self.env_id='MsPacmanNoFrameskip-v4' 712 | self.cuda=True 713 | self.seed=0 714 | 715 | args = parser() 716 | with open(args.config) as f: 717 | config = yaml.load(f, Loader=yaml.SafeLoader) 718 | 719 | envs=[] 720 | for env_name in ['YarsRevenge', 'Jamesbond', 'FishingDerby', 'Venture', 721 | 'DoubleDunk', 'Kangaroo', 'IceHockey', 'ChopperCommand', 'Krull', 722 | 'Robotank', 'BankHeist', 'RoadRunner', 'Hero', 'Boxing', 723 | 'Seaquest', 'PrivateEye', 'StarGunner', 'Riverraid', 724 | 'Zaxxon', 'Tennis', 'BattleZone', 725 | 'MontezumaRevenge', 'Frostbite', 'Gravitar', 726 | 'Defender', 'Pitfall', 'Solaris', 'Berzerk', 727 | 'Centipede'][:10]: 728 | env = gym.make(env_name) 729 | envs.append(env) 730 | 731 | rloss = [0.0 for i in range(len(envs))] 732 | rewardsRec = [[] for i in range(len(envs))] 733 | rewardsRec_nor = [[0] for i in range(len(envs))] 734 | succeessRec = [[] for i in range(len(envs))] 735 | 736 | agents = [] 737 | for index in range(len(envs)): 738 | # Create environments. 739 | env = envs[index] 740 | test_env = envs[index] 741 | 742 | # Specify the directory to log. 743 | name = args.config.split('/')[-1].rstrip('.yaml') 744 | if args.shared: 745 | name = 'shared-' + name 746 | time = datetime.now().strftime("%Y%m%d-%H%M") 747 | log_dir = os.path.join('logs_sac_newatari', str(index), f'{name}-seed{args.seed}-{time}') 748 | 749 | # Create the agent. 
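    # args.shared is hard-coded to True in the parser class above, so SharedSacdAgent is always
    # selected here; the conditional never evaluates the name SacdAgent, which is not defined in
    # this file.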
750 | Agent = SacdAgent if not args.shared else SharedSacdAgent 751 | agent = Agent( 752 | env=env, test_env=test_env, log_dir=log_dir, cuda=args.cuda, 753 | seed=args.seed, **config) 754 | agents.append(agent) 755 | 756 | for i_episode in range(10000): 757 | for index in range(len(envs)): 758 | rnd = i_episode 759 | agents[index].run(rnd) 760 | -------------------------------------------------------------------------------- /soft-module.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import metaworld 5 | from metaworld.envs.mujoco.sawyer_xyz import * 6 | import random 7 | import time 8 | import numpy as np 9 | import math 10 | import copy 11 | from collections import deque 12 | import gym 13 | from gym import Wrapper 14 | from gym.spaces import Box 15 | import torch 16 | import torch.optim as optim 17 | from torch import nn as nn 18 | import torchrl.algo.utils as atu 19 | from torchrl.utils import get_params 20 | from torchrl.env import get_env 21 | from torchrl.utils import Logger 22 | import torchrl.policies as policies 23 | import torchrl.networks as networks 24 | from torchrl.collector.base import BaseCollector 25 | from torchrl.algo import SAC 26 | from torchrl.algo import TwinSAC 27 | from torchrl.algo import TwinSACQ 28 | from torchrl.algo import MTSAC 29 | from torchrl.collector.para import ParallelCollector 30 | from torchrl.collector.para import AsyncParallelCollector 31 | from torchrl.collector.para.mt import SingleTaskParallelCollectorBase 32 | from torchrl.replay_buffers import BaseReplayBuffer 33 | from torchrl.replay_buffers.shared import SharedBaseReplayBuffer 34 | from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer 35 | from torchrl.env.continuous_wrapper import * 36 | from torchrl.env.get_env import wrap_continuous_env 37 | import matplotlib.pyplot as plt 38 | import constopt 39 | from constopt.constraints import LinfBall 40 | from constopt.stochastic import PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe 41 | import torch 42 | from torch.autograd import Variable 43 | import torch.nn.utils as utils 44 | from scipy.stats import rankdata 45 | os.environ['LD_LIBRARY_PATH'] = '/root/.mujoco/mujoco210/bin:/usr/lib/nvidia' 46 | sys.path.insert(0, './metaworld-master/') 47 | sys.path.append("./") 48 | sys.path.append("../..") 49 | sys.path.insert(0, r'./constopt-pytorch/') 50 | 51 | class SingleWrapper(Wrapper): 52 | def __init__(self, env): 53 | self._env = env 54 | self.action_space = env.action_space 55 | self.observation_space = env.observation_space 56 | self.train_mode = True 57 | def reset(self): 58 | return self._env.reset() 59 | 60 | def seed(self, se): 61 | self._env.seed(se) 62 | 63 | def reset_with_index(self, task_idx): 64 | return self._env.reset() 65 | 66 | def step(self, action): 67 | obs, reward, done, info = self._env.step(action) 68 | return obs, reward, done, info 69 | 70 | def train(self): 71 | self.train_mode = True 72 | 73 | def test(self): 74 | self.train_mode = False 75 | def eval(self): 76 | self.train_mode = False 77 | 78 | def render(self, mode = 'human', **kwargs): 79 | return self._env.render(mode = mode, **kwargs) 80 | 81 | def close(self): 82 | self._env.close() 83 | 84 | class Normalizer(): 85 | def __init__(self, shape, clip = 10.): 86 | self.shape = shape 87 | self._mean = np.zeros(shape) 88 | self._var = np.ones(shape) 89 | self._count = 1e-4 90 | self.clip = clip 91 | self.should_estimate = True 92 | 93 | def 
stop_update_estimate(self):
94 |         self.should_estimate = False
95 |
96 |     def update_estimate(self, data):
97 |         if not self.should_estimate:
98 |             return
99 |         if len(data.shape) == self.shape:
100 |             data = data[np.newaxis, :]
101 |         self._mean, self._var, self._count = update_mean_var_count_moments(
102 |             self._mean, self._var, self._count,
103 |             np.mean(data, axis=0), np.var(data, axis=0), data.shape[0])
104 |
105 |     def inverse(self, raw):
106 |         return raw * np.sqrt(self._var) + self._mean
107 |
108 |     def inverse_torch(self, raw):
109 |         return raw * torch.Tensor(np.sqrt(self._var)).to(raw.device) \
110 |             + torch.Tensor(self._mean).to(raw.device)
111 |
112 |     def filt(self, raw):
113 |         return np.clip(
114 |             (raw - self._mean) / (np.sqrt(self._var) + 1e-4),
115 |             -self.clip, self.clip)
116 |
117 |     def filt_torch(self, raw):
118 |         return torch.clamp(
119 |             (raw - torch.Tensor(self._mean).to(raw.device)) / \
120 |             (torch.Tensor(np.sqrt(self._var) + 1e-4).to(raw.device)),
121 |             -self.clip, self.clip)
122 |
123 | class RLAlgo():
124 |     """
125 |     Base RL Algorithm Framework
126 |     """
127 |     def __init__(self,
128 |                  env = None,
129 |                  replay_buffer = None,
130 |                  collector = None,
131 |                  logger = None,
132 |                  continuous = None,
133 |                  discount = 0.99,
134 |                  num_epochs = 3000,
135 |                  epoch_frames = 200,
136 |                  max_episode_frames = 999,
137 |                  batch_size = 128,
138 |                  device = 'cpu',
139 |                  train_render = False,
140 |                  eval_episodes = 1,
141 |                  eval_render = False,
142 |                  save_interval = 100,
143 |                  save_dir = None
144 |                  ):
145 |         self.env = env
146 |         self.total_frames = 0
147 |         self.continuous = isinstance(self.env.action_space, gym.spaces.Box)
148 |
149 |         self.replay_buffer = replay_buffer
150 |         self.collector = collector
151 |         # device specification
152 |         self.device = device
153 |
154 |         # environment relevant information
155 |         self.discount = discount
156 |         self.num_epochs = num_epochs
157 |         self.epoch_frames = epoch_frames
158 |         self.max_episode_frames = max_episode_frames
159 |
160 |         self.train_render = train_render
161 |         self.eval_render = eval_render
162 |
163 |         # training information
164 |         self.batch_size = batch_size
165 |         self.training_update_num = 0
166 |         self.sample_key = None
167 |
168 |         # Logger & relevant setting
169 |         self.logger = logger
170 |         self.episode_rewards = deque(maxlen = 30)
171 |         self.training_episode_rewards = deque(maxlen = 30)
172 |         self.eval_episodes = eval_episodes
173 |
174 |         self.save_interval = save_interval
175 |         self.save_dir = save_dir
176 |         if not osp.exists( self.save_dir ):
177 |             os.mkdir( self.save_dir )
178 |
179 |         self.best_eval = None
180 |
181 |     def start_epoch(self):
182 |         pass
183 |
184 |     def finish_epoch(self):
185 |         return {}
186 |
187 |     def pretrain(self):
188 |         pass
189 |
190 |     def update_per_epoch(self):
191 |         pass
192 |
193 |     def snapshot(self, prefix, epoch):
194 |         for name, network in self.snapshot_networks:
195 |             model_file_name = "model_{}_{}.pth".format(name, epoch)
196 |             model_path = osp.join(prefix, model_file_name)
197 |             torch.save(network.state_dict(), model_path)
198 |
199 |     def train(self, epoch):
200 |         if epoch == 1:
201 |             self.pretrain()
202 |             self.total_frames = 0
203 |             if hasattr(self, "pretrain_frames"):
204 |                 self.total_frames = self.pretrain_frames
205 |
206 |         self.start_epoch()
207 |
208 |         self.current_epoch = epoch
209 |         start = time.time()
210 |
211 |         self.start_epoch()
212 |
213 |         explore_start_time = time.time()
214 |         training_epoch_info = self.collector.train_one_epoch()
215 |         for reward in training_epoch_info["train_rewards"]:
216 |             self.training_episode_rewards.append(reward)
self.training_episode_rewards.append(reward) 217 | explore_time = time.time() - explore_start_time 218 | 219 | train_start_time = time.time() 220 | loss = self.update_per_epoch() 221 | train_time = time.time() - train_start_time 222 | 223 | finish_epoch_info = self.finish_epoch() 224 | eval_start_time = time.time() 225 | eval_infos = self.collector.eval_one_epoch() 226 | eval_time = time.time() - eval_start_time 227 | 228 | self.total_frames += self.collector.active_worker_nums * self.epoch_frames 229 | 230 | infos = {} 231 | 232 | for reward in eval_infos["eval_rewards"]: 233 | self.episode_rewards.append(reward) 234 | # del eval_infos["eval_rewards"] 235 | 236 | if self.best_eval is None or \ 237 | np.mean(eval_infos["eval_rewards"]) > self.best_eval: 238 | self.best_eval = np.mean(eval_infos["eval_rewards"]) 239 | self.snapshot(self.save_dir, 'best') 240 | del eval_infos["eval_rewards"] 241 | infos["eval_avg_success_rate"] = eval_infos["success"] 242 | infos["Running_Average_Rewards"] = np.mean(self.episode_rewards) 243 | infos["Running_success_rate"] = training_epoch_info["train_success_rate"] 244 | infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"] 245 | infos["Running_Training_Average_Rewards"] = np.mean( 246 | self.training_episode_rewards) 247 | infos["Explore_Time"] = explore_time 248 | infos["Train___Time"] = train_time 249 | infos["Eval____Time"] = eval_time 250 | infos.update(eval_infos) 251 | infos.update(finish_epoch_info) 252 | 253 | self.logger.add_epoch_info(epoch, self.total_frames, 254 | time.time() - start, infos ) 255 | 256 | if epoch % self.save_interval == 0: 257 | self.snapshot(self.save_dir, epoch) 258 | if epoch = = self.num_epochs-1: 259 | self.snapshot(self.save_dir, "finish") 260 | self.collector.terminate() 261 | return loss 262 | 263 | def update(self, batch): 264 | raise NotImplementedError 265 | 266 | def _update_target_networks(self): 267 | if self.use_soft_update: 268 | for net, target_net in self.target_networks: 269 | atu.soft_update_from_to(net, target_net, self.tau) 270 | else: 271 | if self.training_update_num % self.target_hard_update_period == 0: 272 | for net, target_net in self.target_networks: 273 | atu.copy_model_params_from_to(net, target_net) 274 | 275 | @property 276 | def networks(self): 277 | return [ 278 | ] 279 | 280 | @property 281 | def snapshot_networks(self): 282 | return [ 283 | ] 284 | 285 | @property 286 | def target_networks(self): 287 | return [ 288 | ] 289 | 290 | def to(self, device): 291 | for net in self.networks: 292 | net.to(device) 293 | 294 | class OffRLAlgo(RLAlgo): 295 | """ 296 | Base RL Algorithm Framework 297 | """ 298 | def __init__(self, 299 | 300 | pretrain_epochs = 0, 301 | 302 | min_pool = 0, 303 | 304 | target_hard_update_period = 1000, 305 | use_soft_update = True, 306 | tau = 0.001, 307 | opt_times = 1, 308 | 309 | **kwargs 310 | ): 311 | super(OffRLAlgo, self).__init__(**kwargs) 312 | 313 | # environment relevant information 314 | self.pretrain_epochs = pretrain_epochs 315 | 316 | # target_network update information 317 | self.target_hard_update_period = target_hard_update_period 318 | self.use_soft_update = use_soft_update 319 | self.tau = tau 320 | 321 | # training information 322 | self.opt_times = opt_times 323 | self.min_pool = min_pool 324 | 325 | self.sample_key = [ "obs", "next_obs", "acts", "rewards", "terminals" ] 326 | 327 | def update_per_timestep(self): 328 | if self.replay_buffer.num_steps_can_sample() > max( self.min_pool, self.batch_size ): 329 | for _ in range( 
self.opt_times ): 330 | batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key) 331 | infos = self.update( batch ) 332 | self.logger.add_update_info( infos ) 333 | 334 | def update_per_epoch(self): 335 | loss = [] 336 | for _ in range( self.opt_times ): 337 | batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key) 338 | infos = self.update( batch ) 339 | loss.append(infos['Training/policy_loss']) 340 | self.logger.add_update_info( infos ) 341 | return np.mean(loss) 342 | 343 | def pretrain(self): 344 | total_frames = 0 345 | self.pretrain_epochs * self.collector.worker_nums * self.epoch_frames 346 | 347 | for pretrain_epoch in range( self.pretrain_epochs ): 348 | 349 | start = time.time() 350 | 351 | self.start_epoch() 352 | 353 | training_epoch_info = self.collector.train_one_epoch() 354 | for reward in training_epoch_info["train_rewards"]: 355 | self.training_episode_rewards.append(reward) 356 | 357 | finish_epoch_info = self.finish_epoch() 358 | 359 | total_frames += self.collector.active_worker_nums * self.epoch_frames 360 | 361 | infos = {} 362 | 363 | infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"] 364 | infos["Running_Training_Average_Rewards"] = np.mean(self.training_episode_rewards) 365 | infos.update(finish_epoch_info) 366 | 367 | self.logger.add_epoch_info(pretrain_epoch, total_frames, time.time() - start, infos, csv_write = False ) 368 | 369 | self.pretrain_frames = total_frames 370 | 371 | self.logger.log("Finished Pretrain") 372 | 373 | class SAC(OffRLAlgo): 374 | """ 375 | SAC 376 | """ 377 | def __init__( 378 | self, 379 | pf, vf, qf, 380 | plr, vlr, qlr, 381 | optimizer_class = optim.Adam, 382 | 383 | policy_std_reg_weight = 1e-3, 384 | policy_mean_reg_weight = 1e-3, 385 | 386 | reparameterization = True, 387 | automatic_entropy_tuning = True, 388 | target_entropy = None, 389 | **kwargs 390 | ): 391 | super(SAC, self).__init__(**kwargs) 392 | self.pf = pf 393 | self.qf = qf 394 | self.vf = vf 395 | self.target_vf = copy.deepcopy(vf) 396 | self.to(self.device) 397 | 398 | self.plr = plr 399 | self.vlr = vlr 400 | self.qlr = qlr 401 | 402 | self.qf_optimizer = optimizer_class( 403 | self.qf.parameters(), 404 | lr = self.qlr, 405 | ) 406 | 407 | self.vf_optimizer = optimizer_class( 408 | self.vf.parameters(), 409 | lr = self.vlr, 410 | ) 411 | 412 | self.pf_optimizer = optimizer_class( 413 | self.pf.parameters(), 414 | lr = self.plr, 415 | ) 416 | 417 | self.automatic_entropy_tuning = automatic_entropy_tuning 418 | if self.automatic_entropy_tuning: 419 | if target_entropy: 420 | self.target_entropy = target_entropy 421 | else: 422 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # from rlkit 423 | self.log_alpha = torch.zeros(1).to(self.device) 424 | self.log_alpha.requires_grad_() 425 | self.alpha_optimizer = optimizer_class( 426 | [self.log_alpha], 427 | lr = self.plr, 428 | ) 429 | 430 | self.qf_criterion = nn.MSELoss() 431 | self.vf_criterion = nn.MSELoss() 432 | 433 | self.policy_std_reg_weight = policy_std_reg_weight 434 | self.policy_mean_reg_weight = policy_mean_reg_weight 435 | 436 | self.reparameterization = reparameterization 437 | 438 | def update(self, batch): 439 | self.training_update_num += 1 440 | 441 | obs = batch['obs'] 442 | actions = batch['acts'] 443 | next_obs = batch['next_obs'] 444 | rewards = batch['rewards'] 445 | terminals = batch['terminals'] 446 | 447 | rewards = torch.Tensor(rewards).to( self.device ) 448 | terminals = torch.Tensor(terminals).to( self.device ) 449 | 
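# The update below follows the standard value-network variant of SAC (assuming the usual objectives):
#   alpha_loss = -E[ log_alpha * (log_pi + target_entropy).detach() ]
#   qf_loss    = MSE( Q(s, a), r + gamma * (1 - done) * V_target(s') )
#   vf_loss    = MSE( V(s), Q(s, a_new) - alpha * log_pi )
#   pi_loss    = E[ alpha * log_pi - Q(s, a_new) ] + std/mean regularizers
# The only multi-task addition is in the policy step: when the `multitask` flag is on, the cross-task weighting
# term lossw(currindex, index, rloss, w, B)/10 is added to the per-task policy loss before backward().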
obs = torch.Tensor(obs).to( self.device ) 450 | actions = torch.Tensor(actions).to( self.device ) 451 | next_obs = torch.Tensor(next_obs).to( self.device ) 452 | 453 | """ 454 | Policy operations. 455 | """ 456 | sample_info = self.pf.explore(obs, return_log_probs = True ) 457 | 458 | mean = sample_info["mean"] 459 | log_std = sample_info["log_std"] 460 | new_actions = sample_info["action"] 461 | log_probs = sample_info["log_prob"] 462 | ent = sample_info["ent"] 463 | 464 | q_pred = self.qf([obs, actions]) 465 | v_pred = self.vf(obs) 466 | 467 | if self.automatic_entropy_tuning: 468 | """ 469 | Alpha Loss 470 | """ 471 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() 472 | self.alpha_optimizer.zero_grad() 473 | alpha_loss.backward() 474 | self.alpha_optimizer.step() 475 | alpha = self.log_alpha.exp() 476 | else: 477 | alpha = 1 478 | alpha_loss = 0 479 | 480 | """ 481 | QF Loss 482 | """ 483 | target_v_values = self.target_vf(next_obs) 484 | q_target = rewards + (1. - terminals) * self.discount * target_v_values 485 | qf_loss = self.qf_criterion( q_pred, q_target.detach()) 486 | 487 | """ 488 | VF Loss 489 | """ 490 | q_new_actions = self.qf([obs, new_actions]) 491 | v_target = q_new_actions - alpha * log_probs 492 | vf_loss = self.vf_criterion( v_pred, v_target.detach()) 493 | 494 | """ 495 | Policy Loss 496 | """ 497 | if not self.reparameterization: 498 | log_policy_target = q_new_actions - v_pred 499 | policy_loss = ( 500 | log_probs * ( alpha * log_probs - log_policy_target).detach() 501 | ).mean() 502 | else: 503 | policy_loss = ( alpha * log_probs - q_new_actions).mean() 504 | 505 | std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean() 506 | mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean() 507 | 508 | policy_loss += std_reg_loss + mean_reg_loss 509 | 510 | """ 511 | Update Networks 512 | """ 513 | self.pf_optimizer.zero_grad() 514 | 515 | w = [] 516 | for key in pfs[0].state_dict().keys(): 517 | w.append(torch.cat([pfs[j].state_dict()[key].unsqueeze(0) for j in range(len(envs))])) 518 | 519 | rloss[index] = policy_loss.clone().detach().item() 520 | 521 | if multitask: 522 | pre = rloss[index] + lossw(currindex, index, rloss, w, B)/10 523 | else: 524 | pre = rloss[index] 525 | # compute gradients 526 | 527 | pre.backward() 528 | 529 | # train the NN 530 | 531 | self.pf_optimizer.step() 532 | rloss[index] = rloss[index].detach().item() 533 | 534 | self.qf_optimizer.zero_grad() 535 | qf_loss.backward() 536 | self.qf_optimizer.step() 537 | 538 | self.vf_optimizer.zero_grad() 539 | vf_loss.backward() 540 | self.vf_optimizer.step() 541 | 542 | self._update_target_networks() 543 | 544 | # Information For Logger 545 | info = {} 546 | info['Reward_Mean'] = rewards.mean().item() 547 | 548 | if self.automatic_entropy_tuning: 549 | info["Alpha"] = alpha.item() 550 | info["Alpha_loss"] = alpha_loss.item() 551 | info['Training/policy_loss'] = policy_loss.item() 552 | info['Training/vf_loss'] = vf_loss.item() 553 | info['Training/qf_loss'] = qf_loss.item() 554 | 555 | info['log_std/mean'] = log_std.mean().item() 556 | info['log_std/std'] = log_std.std().item() 557 | info['log_std/max'] = log_std.max().item() 558 | info['log_std/min'] = log_std.min().item() 559 | 560 | info['log_probs/mean'] = log_std.mean().item() 561 | info['log_probs/std'] = log_std.std().item() 562 | info['log_probs/max'] = log_std.max().item() 563 | info['log_probs/min'] = log_std.min().item() 564 | 565 | info['mean/mean'] = mean.mean().item() 566 | info['mean/std'] = 
mean.std().item() 567 | info['mean/max'] = mean.max().item() 568 | info['mean/min'] = mean.min().item() 569 | 570 | return info 571 | 572 | @property 573 | def networks(self): 574 | return [ 575 | self.pf, 576 | self.qf, 577 | self.vf, 578 | self.target_vf 579 | ] 580 | 581 | @property 582 | def snapshot_networks(self): 583 | return [ 584 | ["pf", self.pf], 585 | ["qf", self.qf], 586 | ["vf", self.vf] 587 | ] 588 | 589 | @property 590 | def target_networks(self): 591 | return [ 592 | ( self.vf, self.target_vf ) 593 | ] 594 | 595 | class parser: 596 | def __init__(self): 597 | self.config = 'config/sac_ant.json' 598 | self.id = 'mt10' 599 | self.worker_nums = 10 600 | self.eval_worker_nums = 10 601 | self.seed = 20 602 | self.vec_env_nums = 1 603 | self.save_dir = './save/sac_ant' 604 | self.log_dir = './log/sac_ant' 605 | self.no_cuda = True 606 | self.overwrite = True 607 | self.device = 'cpu' 608 | self.cuda = False 609 | 610 | ml10 = metaworld.MT10() # Construct the benchmark, sampling tasks 611 | envs = [] 612 | for name, env_cls in ml10.train_classes.items(): 613 | env = env_cls() 614 | task = random.choice([task for task in ml10.train_tasks 615 | if task.env_name == name]) 616 | env.set_task(task) 617 | envs.append(env) 618 | 619 | for env in envs: 620 | obs = env.reset() # Reset environment 621 | a = env.action_space.sample() # Sample an action 622 | obs, reward, done, info = env.step(a) # Step the environoment with the sampled random action 623 | 624 | text = """{ 625 | "env_name" : "mt10", 626 | "env":{ 627 | "reward_scale":1, 628 | "obs_norm":false 629 | }, 630 | "meta_env":{ 631 | "obs_type": "with_goal_and_id" 632 | }, 633 | "replay_buffer":{ 634 | "size": 1e6 635 | }, 636 | "net":{ 637 | "hidden_shapes": [128, 64, 32, 16], 638 | "append_hidden_shapes":[] 639 | }, 640 | "general_setting": { 641 | "discount" : 0.99, 642 | "pretrain_epochs" : 20, 643 | "num_epochs" : 7500, 644 | "epoch_frames" : 200, 645 | "max_episode_frames" : 200, 646 | 647 | "batch_size" : 1280, 648 | "min_pool" : 10000, 649 | 650 | "target_hard_update_period" : 1000, 651 | "use_soft_update" : true, 652 | "tau" : 0.005, 653 | "opt_times" : 200, 654 | 655 | "eval_episodes" : 3 656 | }, 657 | "sac":{ 658 | } 659 | }""" 660 | 661 | !mkdir config 662 | with open('config/sac_ant.json', 'w') as f: 663 | f.write(text) 664 | 665 | args = parser() 666 | params = get_params(args.config) 667 | device = torch.device( 668 | "cuda:{}".format(args.device) if args.cuda else "cpu") 669 | if args.cuda: 670 | torch.cuda.manual_seed_all(args.seed) 671 | torch.backends.cudnn.benchmark = False 672 | torch.backends.cudnn.deterministic = True 673 | 674 | pfs = [] 675 | qf1s = [] 676 | vfs = [] 677 | agents = [] 678 | epochs = [1 for i in range(len(envs))] 679 | env.seed(args.seed) 680 | torch.manual_seed(args.seed) 681 | np.random.seed(args.seed) 682 | random.seed(args.seed) 683 | normalizer = Normalizer(env.observation_space.shape) 684 | buffer_param = params['replay_buffer'] 685 | experiment_name = os.path.split( 686 | os.path.splitext(args.config)[0])[-1] if args.id is None \ 687 | else args.id 688 | logger = Logger( 689 | experiment_name, params['env_name'], args.seed, params, args.log_dir) 690 | 691 | for index in range(len(envs)): 692 | env = SingleWrapper(envs[index]) 693 | params = get_params(args.config) 694 | params['general_setting']['logger'] = Logger( 695 | 'mt10', str(index), args.seed, params, './log/mt10_' + str(index) + '/') 696 | params['env_name'] = str(index) 697 | params['general_setting']['env'] = env 698 | 699 
| replay_buffer = BaseReplayBuffer( 700 | max_replay_buffer_size = int(buffer_param['size']) 701 | ) 702 | params['general_setting']['replay_buffer'] = replay_buffer 703 | 704 | params['general_setting']['device'] = device 705 | 706 | params['net']['base_type'] = networks.MLPBase 707 | params['net']['activation_func'] = torch.nn.ReLU 708 | 709 | pf = policies.GuassianContPolicy( 710 | input_shape = env.observation_space.shape[0], 711 | output_shape = 2 * env.action_space.shape[0], 712 | **params['net'], 713 | **params['sac']) 714 | 715 | qf1 = networks.QNet( 716 | input_shape = env.observation_space.shape[0] + env.action_space.shape[0], 717 | output_shape = 1, 718 | **params['net']) 719 | 720 | vf = networks.Net( 721 | input_shape = env.observation_space.shape, 722 | output_shape = 1, 723 | **params['net'] 724 | ) 725 | pfs.append(pf) 726 | qf1s.append(qf1) 727 | vfs.append(vf) 728 | params['general_setting']['collector'] = BaseCollector( 729 | env = env, pf = pf, 730 | replay_buffer = replay_buffer, device = device, 731 | train_render = False 732 | ) 733 | params['general_setting']['save_dir'] = osp.join( 734 | './log/', "model10_" + str(index)) 735 | agent = SAC( 736 | pf = pf, 737 | qf = qf1, plr = 3e-4, vlr = 3e-4, qlr = 3e-4, 738 | vf = vf, 739 | **params["sac"], 740 | **params["general_setting"] 741 | ) 742 | agents.append(agent) 743 | 744 | # differentiable ranking loss 745 | def pss(x, points): 746 | def pss0(x, i): 747 | return torch.tanh(200*torch.tensor(x-i))/2 + 0.5 748 | return len(points)-sum([pss0(x, i) for i in points]) 749 | 750 | def losst(currindex, t, rloss, w, B, mu = 0.2, lamb = [0.01, 0.01, 0.01], U = [13], pi = list(range(len(envs)))): 751 | new_rloss = [i for i in rloss] 752 | new_rloss[t] = new_rloss[t] + 1 753 | rlossRank = 1 + len(envs) - rankdata(new_rloss, method = 'min') 754 | points = B[t] 755 | sim = [sum(nn.CosineSimilarity()(pfs[t].state_dict()['last.weight'].view(-1, 1), pfs[i].state_dict()['last.weight'].view(-1, 1))) for i in range(len(envs))] 756 | sim[t] = sim[t] + 100 757 | rlossRank_renew = 1 + len(envs) - rankdata(sim, method = 'min') 758 | term0 = (1 + mu*sum([torch.norm(torch.tensor(B[t][i]), p = 1)for i in set(list(range(len(envs))))-set([t])]))*rloss[t] 759 | term1 = sum([sum([sum([torch.norm(w[i][s]-sum([B[pi[j]][s]*w[i][pi[j]] for j in range(currindex-1)])-B[t][s]*w[i][t],p=2)**2]) for i in range(len(pfs[0].state_dict().keys()))]) for s in U]) 760 | term2 = torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 761 | term3 = torch.norm(torch.tensor(rlossRank)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 762 | term4 = torch.norm(torch.tensor(rlossRank2)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 763 | terms_history[0].append(term0.detach().numpy()) 764 | terms_history[1].append(term1) 765 | terms_history[2].append(term2.detach().numpy()) 766 | terms_history[3].append(term3.detach().numpy()) 767 | terms_history[4].append(term4.detach().numpy()) 768 | if len(terms_history[0])<=1: 769 | return term0+term1+term2+term3+term4 770 | else: 771 | return 
(1/(np.array(terms_history[0]).std())**2)*term0+(1/(np.array(terms_history[1]).std())**2)*term1+(1/(np.array(terms_history[2]).std())**2)*term2+(1/(np.array(terms_history[3]).std())**2)*term3+(1/(np.array(terms_history[4]).std())**2)*term4+np.log((np.array(terms_history[0]).std())*(np.array(terms_history[1]).std())*(np.array(terms_history[2]).std())*(np.array(terms_history[3]).std())*(np.array(terms_history[4]).std())) 772 | 773 | def lossb(currindex, t, rloss, w, B, mu = 0.2, lamb = [0.01, 0.01, 0.01], U = [13], pi = list(range(len(envs)))): 774 | new_rloss = [i for i in rloss] 775 | new_rloss[t] = new_rloss[t] + 1 776 | rlossRank = 1 + len(envs) - rankdata(new_rloss, method = 'min') 777 | points = B[t] 778 | sim = [sum(nn.CosineSimilarity()(pfs[t].state_dict()['last.weight'].view(-1, 1), pfs[i].state_dict()['last.weight'].view(-1, 1))) for i in range(len(envs))] 779 | sim[t] = sim[t] + 100 780 | rlossRank_renew = 1 + len(envs) - rankdata(sim, method = 'min') 781 | term0 = (1 + mu*sum([torch.norm(B[t][i], p = 1)for i in set(list(range(len(envs))))-set([t])]))*rloss[t] 782 | term1 = sum([sum([sum([torch.norm(w[i][s]-sum([B[pi[j]][s]*w[i][pi[j]] for j in range(currindex-1)])-B[t][s]*w[i][t],p=2)**2]) for i in range(len(pfs[0].state_dict().keys()))]) for s in U]) 783 | term2 = torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 784 | term3 = torch.norm(torch.tensor(rlossRank)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 785 | term4 = torch.norm(torch.tensor(rlossRank2)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 786 | if len(terms_history[0])<=1: 787 | return term0+term1+term2+term3+term4 788 | else: 789 | return (1/(np.array(terms_history[0]).std())**2)*term0+(1/(np.array(terms_history[1]).std())**2)*term1+(1/(np.array(terms_history[2]).std())**2)*term2+(1/(np.array(terms_history[3]).std())**2)*term3+(1/(np.array(terms_history[4]).std())**2)*term4+np.log((np.array(terms_history[0]).std())*(np.array(terms_history[1]).std())*(np.array(terms_history[2]).std())*(np.array(terms_history[3]).std())*(np.array(terms_history[4]).std())) 790 | 791 | def lossw(currindex, t, rloss, w, B, mu = 0.2, lamb = [0.01, 0.01, 0.01], U = [13], pi = list(range(len(envs)))): 792 | new_rloss = [i for i in rloss] 793 | new_rloss[t] = new_rloss[t] + 1 794 | rlossRank = 1 + len(envs) - rankdata(new_rloss, method = 'min') 795 | points = B[t] 796 | term0 = (1 + mu*sum([torch.norm(torch.tensor(B[t][i]), p=1)for i in set(list(range(len(envs))))-set([t])]))*rloss[t] 797 | sim = [sum(nn.CosineSimilarity()(pfs[t].state_dict()['last.weight'].view(-1, 1), pfs[i].state_dict()['last.weight'].view(-1, 1))) for i in range(len(envs))] 798 | sim[t] = sim[t] + 100 799 | rlossRank_renew = 1 + len(envs) - rankdata(sim, method = 'min') 800 | 801 | term1 = sum([sum([torch.norm(w[i][t]-sum([B[pi[j]][t]*w[i][pi[j]] for j in range(currindex-1)]),p=2)**2]) for i in range(len(pfs[0].state_dict().keys()))]) 802 | term2 = torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 803 | term3 = torch.norm(torch.tensor(rlossRank)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 804 | term4 = torch.norm(torch.tensor(rlossRank2)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 805 | if len(terms_history[0])<=1: 806 | return term0+term1+term3+term4 807 | else: 808 | return 
(1/(np.array(terms_history[0]).std())**2)*term0+(1/(np.array(terms_history[1]).std())**2)*term1+(1/(np.array(terms_history[2]).std())**2)*term2+(1/(np.array(terms_history[3]).std())**2)*term3+(1/(np.array(terms_history[4]).std())**2)*term4+np.log((np.array(terms_history[0]).std())*(np.array(terms_history[1]).std())*(np.array(terms_history[2]).std())*(np.array(terms_history[3]).std())*(np.array(terms_history[4]).std())) 809 | 810 | # FrankWolfe 811 | OPTIMIZER_CLASSES = [FrankWolfe] 812 | radius = 0.05 813 | 814 | def setup_problem(make_nonconvex = False): 815 | radius2 = radius 816 | loss_func = lossb 817 | constraint = LinfBall(radius2) 818 | 819 | return loss_func, constraint 820 | 821 | def optimize(loss_func, constraint, optimizer_class, iterations = 100): 822 | for i in range(len(envs)): 823 | if i != t: 824 | B[t][i] = torch.tensor(B[t][i], requires_grad = True) 825 | optimizer = [optimizer_class([B[t][i]], constraint) for i in set(list(range(len(envs))))-set([t])] 826 | iterates = [[B[t][i].data if i != t else B[t][i] for i in range(len(envs))]] 827 | losses = [] 828 | # Use Madry's heuristic for step size 829 | step_size = { 830 | FrankWolfe.name: None, 831 | MomentumFrankWolfe.name: None, 832 | PGD.name: 2.5 * constraint.alpha / iterations * 2., 833 | PGDMadry.name: 2.5 * constraint.alpha / iterations 834 | } 835 | for _ in range(iterations): 836 | for i in range(len(envs)-1): 837 | optimizer[i].zero_grad() 838 | loss = loss_func(currindex, t, rloss, w, B, U = list(set(U)-set(list([t])))) 839 | loss.backward(retain_graph = True) 840 | for i in range(len(envs)-1): 841 | optimizer[i].step(step_size[optimizer[i].name]) 842 | for i in set(list(range(len(envs))))-set([t]): 843 | B[t][i].data.clamp_(0, 100) 844 | losses.append(loss) 845 | iterates.append([B[t][i].data if i != t else B[t][i] for i in range(len(envs))]) 846 | loss = loss_func(currindex, t, rloss, w, B, U = list(set(U)-set(list([t])))).detach() 847 | losses.append(loss) 848 | B[t] = [B[t][i].data if i!
= t else B[t][i] for i in range(len(envs))] 849 | return losses, iterates 850 | 851 | multitask = False 852 | rloss = [0.0 for i in range(len(envs))] 853 | rewardsRec = [[] for i in range(len(envs))] 854 | succeessRec = [[] for i in range(len(envs))] 855 | 856 | for i_episode in range(10000): 857 | for index, env in enumerate(envs): 858 | agents[index].train(epochs[index]) 859 | epochs[index] += 1 860 | -------------------------------------------------------------------------------- /CAMRL.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from scipy.stats import rankdata 4 | import metaworld 5 | from metaworld.envs.mujoco.sawyer_xyz import * 6 | import random 7 | import time 8 | import os.path as osp 9 | import numpy as np 10 | import torch 11 | from torch.autograd import Variable 12 | import torch.nn.utils as utils 13 | from torchrl.utils import get_params 14 | from torchrl.env import get_env 15 | from torchrl.utils import Logger 16 | import torchrl.policies as policies 17 | import torchrl.networks as networks 18 | from torchrl.collector.base import BaseCollector 19 | from torchrl.algo import SAC 20 | from torchrl.algo import TwinSAC 21 | from torchrl.algo import TwinSACQ 22 | from torchrl.algo import MTSAC 23 | import torchrl.algo.utils as atu 24 | from torchrl.collector.para import ParallelCollector 25 | from torchrl.collector.para import AsyncParallelCollector 26 | from torchrl.collector.para.mt import SingleTaskParallelCollectorBase 27 | from torchrl.replay_buffers import BaseReplayBuffer 28 | from torchrl.replay_buffers.shared import SharedBaseReplayBuffer 29 | from torchrl.replay_buffers.shared import AsyncSharedReplayBuffer 30 | from torchrl.env.continuous_wrapper import * 31 | from torchrl.env.get_env import wrap_continuous_env 32 | import gym 33 | from gym import Wrapper 34 | from gym.spaces import Box 35 | import copy 36 | from collections import deque 37 | import matplotlib.pyplot as plt 38 | import constopt 39 | from constopt.constraints import LinfBall 40 | from constopt.stochastic import PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe 41 | 42 | sys.path.insert(0,r'./constopt-pytorch/') 43 | os.environ['LD_LIBRARY_PATH']='/root/.mujoco/mujoco210/bin:/usr/lib/nvidia' 44 | sys.path.insert(0,'./metaworld-master/') 45 | sys.path.append("./") 46 | sys.path.append("../..") 47 | 48 | class SingleWrapper(Wrapper): 49 | def __init__(self, env): 50 | self._env = env 51 | self.action_space = env.action_space 52 | self.observation_space = env.observation_space 53 | self.train_mode = True 54 | 55 | def reset(self): 56 | return self._env.reset() 57 | 58 | def seed(self, se): 59 | self._env.seed(se) 60 | 61 | def reset_with_index(self, task_idx): 62 | return self._env.reset() 63 | 64 | def step(self, action): 65 | obs, reward, done, info = self._env.step(action) 66 | return obs, reward, done, info 67 | 68 | def train(self): 69 | self.train_mode = True 70 | 71 | def test(self): 72 | self.train_mode = False 73 | def eval(self): 74 | self.train_mode = False 75 | 76 | def render(self, mode='human', **kwargs): 77 | return self._env.render(mode=mode, **kwargs) 78 | 79 | def close(self): 80 | self._env.close() 81 | 82 | class Normalizer(): 83 | def __init__(self, shape, clip=10.): 84 | self.shape = shape 85 | self._mean = np.zeros(shape) 86 | self._var = np.ones(shape) 87 | self._count = 1e-4 88 | self.clip = clip 89 | self.should_estimate = True 90 | 91 | def stop_update_estimate(self): 92 | self.should_estimate = False 93 | 94 | 
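# update_estimate() below maintains running observation statistics (mean/var/count).
# The helper update_mean_var_count_moments is neither defined nor imported in this file; it is assumed to be the
# standard parallel-moments update used by RunningMeanStd-style normalizers (cf. update_mean_var_count_from_moments
# in gym / baselines), roughly:
#   delta     = batch_mean - mean;            tot = count + batch_count
#   new_mean  = mean + delta * batch_count / tot
#   new_var   = (var * count + batch_var * batch_count + delta**2 * count * batch_count / tot) / tot
#   new_count = tot
# Note: the guard `len(data.shape) == self.shape` compares an int with a shape tuple, so the batch axis is never
# actually added before the update.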
def update_estimate(self, data): 95 | if not self.should_estimate: 96 | return 97 | if len(data.shape) == self.shape: 98 | data = data[np.newaxis, :] 99 | self._mean, self._var, self._count = update_mean_var_count_moments( 100 | self._mean, self._var, self._count, 101 | np.mean(data, axis=0), np.var(data, axis=0), data.shape[0]) 102 | 103 | def inverse(self, raw): 104 | return raw * np.sqrt(self._var) + self._mean 105 | 106 | def inverse_torch(self, raw): 107 | return raw * torch.Tensor(np.sqrt(self._var)).to(raw.device) \ 108 | + torch.Tensor(self._mean).to(raw.device) 109 | 110 | def filt(self, raw): 111 | return np.clip( 112 | (raw - self._mean) / (np.sqrt(self._var) + 1e-4), 113 | -self.clip, self.clip) 114 | 115 | def filt_torch(self, raw): 116 | return torch.clamp( 117 | (raw - torch.Tensor(self._mean).to(raw.device)) / \ 118 | (torch.Tensor(np.sqrt(self._var) + 1e-4).to(raw.device)), 119 | -self.clip, self.clip) 120 | 121 | class RLAlgo(): 122 | """ 123 | Base RL Algorithm Framework 124 | """ 125 | def __init__(self, 126 | env = None, 127 | replay_buffer = None, 128 | collector = None, 129 | logger = None, 130 | continuous = None, 131 | discount=0.99, 132 | num_epochs = 3000, 133 | epoch_frames = 200, 134 | max_episode_frames = 999, 135 | batch_size = 128, 136 | device = 'cpu', 137 | train_render = False, 138 | eval_episodes = 1, 139 | eval_render = False, 140 | save_interval = 100, 141 | save_dir = None 142 | ): 143 | 144 | self.env = env 145 | self.total_frames = 0 146 | self.continuous = isinstance(self.env.action_space, gym.spaces.Box) 147 | 148 | self.replay_buffer = replay_buffer 149 | self.collector = collector 150 | # device specification 151 | self.device = device 152 | 153 | # environment relevant information 154 | self.discount = discount 155 | self.num_epochs = num_epochs 156 | self.epoch_frames = epoch_frames 157 | self.max_episode_frames = max_episode_frames 158 | 159 | self.train_render = train_render 160 | self.eval_render = eval_render 161 | 162 | # training information 163 | self.batch_size = batch_size 164 | self.training_update_num = 0 165 | self.sample_key = None 166 | 167 | # Logger & relevant setting 168 | self.logger = logger 169 | 170 | 171 | self.episode_rewards = deque(maxlen=30) 172 | self.training_episode_rewards = deque(maxlen=30) 173 | self.eval_episodes = eval_episodes 174 | 175 | self.save_interval = save_interval 176 | self.save_dir = save_dir 177 | if not osp.exists( self.save_dir ): 178 | os.mkdir( self.save_dir ) 179 | 180 | self.best_eval = None 181 | 182 | def start_epoch(self): 183 | pass 184 | 185 | def finish_epoch(self): 186 | return {} 187 | 188 | def pretrain(self): 189 | pass 190 | 191 | def update_per_epoch(self): 192 | pass 193 | 194 | def snapshot(self, prefix, epoch): 195 | for name, network in self.snapshot_networks: 196 | model_file_name="model_{}_{}.pth".format(name, epoch) 197 | model_path=osp.join(prefix, model_file_name) 198 | torch.save(network.state_dict(), model_path) 199 | 200 | def train(self,epoch): 201 | if epoch==1: 202 | # self.pf.load_state_dict(torch.load('/root/metaworld-master/newsoftmodule_24/model'+str(index)+'/model_pf_best.pth')) 203 | # self.qf.load_state_dict(torch.load('/root/metaworld-master/newsoftmodule_24/model'+str(index)+'/model_qf_best.pth')) 204 | # self.vf.load_state_dict(torch.load('/root/metaworld-master/newsoftmodule_24/model'+str(index)+'/model_vf_best.pth')) 205 | 206 | self.pretrain() 207 | self.total_frames = 0 208 | if hasattr(self, "pretrain_frames"): 209 | self.total_frames = 
self.pretrain_frames 210 | 211 | self.start_epoch() 212 | 213 | self.current_epoch = epoch 214 | start = time.time() 215 | 216 | self.start_epoch() 217 | 218 | explore_start_time = time.time() 219 | training_epoch_info = self.collector.train_one_epoch() 220 | for reward in training_epoch_info["train_rewards"]: 221 | self.training_episode_rewards.append(reward) 222 | explore_time = time.time() - explore_start_time 223 | 224 | train_start_time = time.time() 225 | loss=self.update_per_epoch() 226 | train_time = time.time() - train_start_time 227 | 228 | finish_epoch_info = self.finish_epoch() 229 | 230 | eval_start_time = time.time() 231 | eval_infos = self.collector.eval_one_epoch() 232 | eval_time = time.time() - eval_start_time 233 | 234 | self.total_frames += self.collector.active_worker_nums * self.epoch_frames 235 | 236 | infos = {} 237 | 238 | for reward in eval_infos["eval_rewards"]: 239 | self.episode_rewards.append(reward) 240 | # del eval_infos["eval_rewards"] 241 | 242 | if self.best_eval is None or \ 243 | np.mean(eval_infos["eval_rewards"]) > self.best_eval: 244 | self.best_eval = np.mean(eval_infos["eval_rewards"]) 245 | self.snapshot(self.save_dir, 'best') 246 | del eval_infos["eval_rewards"] 247 | infos["eval_avg_success_rate"] =eval_infos["success"] 248 | infos["Running_Average_Rewards"] = np.mean(self.episode_rewards) 249 | infos["Running_success_rate"] =training_epoch_info["train_success_rate"] 250 | infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"] 251 | infos["Running_Training_Average_Rewards"] = np.mean( 252 | self.training_episode_rewards) 253 | infos["Explore_Time"] = explore_time 254 | infos["Train___Time"] = train_time 255 | infos["Eval____Time"] = eval_time 256 | infos.update(eval_infos) 257 | infos.update(finish_epoch_info) 258 | 259 | self.logger.add_epoch_info(epoch, self.total_frames, 260 | time.time() - start, infos ) 261 | 262 | if epoch % self.save_interval == 0: 263 | self.snapshot(self.save_dir, epoch) 264 | if epoch==self.num_epochs-1: 265 | self.snapshot(self.save_dir, "finish") 266 | self.collector.terminate() 267 | return loss 268 | def update(self, batch): 269 | raise NotImplementedError 270 | 271 | def _update_target_networks(self): 272 | if self.use_soft_update: 273 | for net, target_net in self.target_networks: 274 | atu.soft_update_from_to(net, target_net, self.tau) 275 | else: 276 | if self.training_update_num % self.target_hard_update_period == 0: 277 | for net, target_net in self.target_networks: 278 | atu.copy_model_params_from_to(net, target_net) 279 | 280 | @property 281 | def networks(self): 282 | return [ 283 | ] 284 | 285 | @property 286 | def snapshot_networks(self): 287 | return [ 288 | ] 289 | 290 | @property 291 | def target_networks(self): 292 | return [ 293 | ] 294 | 295 | def to(self, device): 296 | for net in self.networks: 297 | net.to(device) 298 | 299 | class OffRLAlgo(RLAlgo): 300 | """ 301 | Base RL Algorithm Framework 302 | """ 303 | def __init__(self, 304 | 305 | pretrain_epochs=0, 306 | 307 | min_pool = 0, 308 | 309 | target_hard_update_period = 1000, 310 | use_soft_update = True, 311 | tau = 0.001, 312 | opt_times = 1, 313 | 314 | **kwargs 315 | ): 316 | super(OffRLAlgo, self).__init__(**kwargs) 317 | 318 | # environment relevant information 319 | self.pretrain_epochs = pretrain_epochs 320 | 321 | # target_network update information 322 | self.target_hard_update_period = target_hard_update_period 323 | self.use_soft_update = use_soft_update 324 | self.tau = tau 325 | 326 | # training information 327 
| self.opt_times = opt_times 328 | self.min_pool = min_pool 329 | 330 | self.sample_key = [ "obs", "next_obs", "acts", "rewards", "terminals" ] 331 | 332 | def update_per_timestep(self): 333 | if self.replay_buffer.num_steps_can_sample() > max( self.min_pool, self.batch_size ): 334 | for _ in range( self.opt_times ): 335 | batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key) 336 | infos = self.update( batch ) 337 | self.logger.add_update_info( infos ) 338 | 339 | def update_per_epoch(self): 340 | loss=[] 341 | for _ in range( self.opt_times ): 342 | batch = self.replay_buffer.random_batch(self.batch_size, self.sample_key) 343 | infos = self.update( batch ) 344 | loss.append(infos['Training/policy_loss']) 345 | self.logger.add_update_info( infos ) 346 | return np.mean(loss) 347 | def pretrain(self): 348 | total_frames = 0 349 | self.pretrain_epochs * self.collector.worker_nums * self.epoch_frames 350 | 351 | for pretrain_epoch in range( self.pretrain_epochs ): 352 | 353 | start = time.time() 354 | 355 | self.start_epoch() 356 | 357 | training_epoch_info = self.collector.train_one_epoch() 358 | for reward in training_epoch_info["train_rewards"]: 359 | self.training_episode_rewards.append(reward) 360 | 361 | finish_epoch_info = self.finish_epoch() 362 | 363 | total_frames += self.collector.active_worker_nums * self.epoch_frames 364 | 365 | infos = {} 366 | 367 | infos["Train_Epoch_Reward"] = training_epoch_info["train_epoch_reward"] 368 | infos["Running_Training_Average_Rewards"] = np.mean(self.training_episode_rewards) 369 | infos.update(finish_epoch_info) 370 | 371 | self.logger.add_epoch_info(pretrain_epoch, total_frames, time.time() - start, infos, csv_write=False ) 372 | 373 | self.pretrain_frames = total_frames 374 | 375 | self.logger.log("Finished Pretrain") 376 | 377 | class SAC(OffRLAlgo): 378 | """ 379 | SAC 380 | """ 381 | def __init__( 382 | self, 383 | pf, vf, qf, 384 | plr,vlr,qlr, 385 | optimizer_class=optim.Adam, 386 | 387 | policy_std_reg_weight=1e-3, 388 | policy_mean_reg_weight=1e-3, 389 | 390 | reparameterization = True, 391 | automatic_entropy_tuning = True, 392 | target_entropy = None, 393 | **kwargs 394 | ): 395 | super(SAC, self).__init__(**kwargs) 396 | self.pf = pf 397 | self.qf = qf 398 | self.vf = vf 399 | self.target_vf = copy.deepcopy(vf) 400 | self.to(self.device) 401 | 402 | self.plr = plr 403 | self.vlr = vlr 404 | self.qlr = qlr 405 | 406 | self.qf_optimizer = optimizer_class( 407 | self.qf.parameters(), 408 | lr=self.qlr, 409 | ) 410 | 411 | self.vf_optimizer = optimizer_class( 412 | self.vf.parameters(), 413 | lr=self.vlr, 414 | ) 415 | 416 | self.pf_optimizer = optimizer_class( 417 | self.pf.parameters(), 418 | lr=self.plr, 419 | ) 420 | 421 | self.automatic_entropy_tuning = automatic_entropy_tuning 422 | if self.automatic_entropy_tuning: 423 | if target_entropy: 424 | self.target_entropy = target_entropy 425 | else: 426 | self.target_entropy = -np.prod(self.env.action_space.shape).item() # from rlkit 427 | self.log_alpha = torch.zeros(1).to(self.device) 428 | self.log_alpha.requires_grad_() 429 | self.alpha_optimizer = optimizer_class( 430 | [self.log_alpha], 431 | lr=self.plr, 432 | ) 433 | 434 | self.qf_criterion = nn.MSELoss() 435 | self.vf_criterion = nn.MSELoss() 436 | 437 | self.policy_std_reg_weight = policy_std_reg_weight 438 | self.policy_mean_reg_weight = policy_mean_reg_weight 439 | 440 | self.reparameterization = reparameterization 441 | 442 | def update(self, batch): 443 | self.training_update_num += 1 444 | 445 | obs = 
batch['obs'] 446 | actions = batch['acts'] 447 | next_obs = batch['next_obs'] 448 | rewards = batch['rewards'] 449 | terminals = batch['terminals'] 450 | 451 | rewards = torch.Tensor(rewards).to( self.device ) 452 | terminals = torch.Tensor(terminals).to( self.device ) 453 | obs = torch.Tensor(obs).to( self.device ) 454 | actions = torch.Tensor(actions).to( self.device ) 455 | next_obs = torch.Tensor(next_obs).to( self.device ) 456 | 457 | """ 458 | Policy operations. 459 | """ 460 | sample_info = self.pf.explore(obs, return_log_probs=True ) 461 | 462 | mean = sample_info["mean"] 463 | log_std = sample_info["log_std"] 464 | new_actions = sample_info["action"] 465 | log_probs = sample_info["log_prob"] 466 | ent = sample_info["ent"] 467 | 468 | q_pred = self.qf([obs, actions]) 469 | v_pred = self.vf(obs) 470 | 471 | if self.automatic_entropy_tuning: 472 | """ 473 | Alpha Loss 474 | """ 475 | alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean() 476 | self.alpha_optimizer.zero_grad() 477 | alpha_loss.backward() 478 | self.alpha_optimizer.step() 479 | alpha = self.log_alpha.exp() 480 | else: 481 | alpha = 1 482 | alpha_loss = 0 483 | 484 | """ 485 | QF Loss 486 | """ 487 | target_v_values = self.target_vf(next_obs) 488 | q_target = rewards + (1. - terminals) * self.discount * target_v_values 489 | qf_loss = self.qf_criterion( q_pred, q_target.detach()) 490 | 491 | """ 492 | VF Loss 493 | """ 494 | q_new_actions = self.qf([obs, new_actions]) 495 | v_target = q_new_actions - alpha * log_probs 496 | vf_loss = self.vf_criterion( v_pred, v_target.detach()) 497 | 498 | """ 499 | Policy Loss 500 | """ 501 | if not self.reparameterization: 502 | log_policy_target = q_new_actions - v_pred 503 | policy_loss = ( 504 | log_probs * ( alpha * log_probs - log_policy_target).detach() 505 | ).mean() 506 | else: 507 | policy_loss = ( alpha * log_probs - q_new_actions).mean() 508 | 509 | std_reg_loss = self.policy_std_reg_weight * (log_std**2).mean() 510 | mean_reg_loss = self.policy_mean_reg_weight * (mean**2).mean() 511 | 512 | policy_loss += std_reg_loss + mean_reg_loss 513 | 514 | """ 515 | Update Networks 516 | """ 517 | self.pf_optimizer.zero_grad() 518 | 519 | w=[] 520 | for key in pfs[0].state_dict().keys(): 521 | w.append(torch.cat([pfs[j].state_dict()[key].unsqueeze(0) for j in range(len(envs))])) 522 | 523 | rloss[index] = policy_loss.clone().detach().item() 524 | if multitask: 525 | pre=rloss[index]+lossw(currindex,index,rloss,w,B)/10 526 | else: 527 | pre=rloss[index] 528 | # compute gradients 529 | 530 | pre.backward() 531 | 532 | # train the NN 533 | 534 | self.pf_optimizer.step() 535 | rloss[index]=rloss[index].detach().item() 536 | 537 | self.qf_optimizer.zero_grad() 538 | qf_loss.backward() 539 | self.qf_optimizer.step() 540 | 541 | self.vf_optimizer.zero_grad() 542 | vf_loss.backward() 543 | self.vf_optimizer.step() 544 | 545 | self._update_target_networks() 546 | 547 | # Information For Logger 548 | info = {} 549 | info['Reward_Mean'] = rewards.mean().item() 550 | 551 | if self.automatic_entropy_tuning: 552 | info["Alpha"] = alpha.item() 553 | info["Alpha_loss"] = alpha_loss.item() 554 | info['Training/policy_loss'] = policy_loss.item() 555 | info['Training/vf_loss'] = vf_loss.item() 556 | info['Training/qf_loss'] = qf_loss.item() 557 | 558 | info['log_std/mean'] = log_std.mean().item() 559 | info['log_std/std'] = log_std.std().item() 560 | info['log_std/max'] = log_std.max().item() 561 | info['log_std/min'] = log_std.min().item() 562 | 563 | 
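# Log-probability diagnostics; note that, as written, the entries below reuse the log_std statistics from above
# rather than summarizing log_probs itself.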
info['log_probs/mean'] = log_std.mean().item() 564 | info['log_probs/std'] = log_std.std().item() 565 | info['log_probs/max'] = log_std.max().item() 566 | info['log_probs/min'] = log_std.min().item() 567 | 568 | info['mean/mean'] = mean.mean().item() 569 | info['mean/std'] = mean.std().item() 570 | info['mean/max'] = mean.max().item() 571 | info['mean/min'] = mean.min().item() 572 | 573 | return info 574 | 575 | @property 576 | def networks(self): 577 | return [ 578 | self.pf, 579 | self.qf, 580 | self.vf, 581 | self.target_vf 582 | ] 583 | 584 | @property 585 | def snapshot_networks(self): 586 | return [ 587 | ["pf", self.pf], 588 | ["qf", self.qf], 589 | ["vf", self.vf] 590 | ] 591 | 592 | @property 593 | def target_networks(self): 594 | return [ 595 | ( self.vf, self.target_vf ) 596 | ] 597 | 598 | class parser: 599 | def __init__(self): 600 | self.config='config/sac_ant.json' 601 | self.id='mt10' 602 | self.worker_nums=10 603 | self.eval_worker_nums=10 604 | self.seed=20 605 | self.vec_env_nums=1 606 | self.save_dir='./save/sac_ant' 607 | self.log_dir='./log/sac_ant' 608 | self.no_cuda=True 609 | self.overwrite=True 610 | self.device='cpu' 611 | self.cuda=False 612 | 613 | text="""{ 614 | "env_name" : "mt10", 615 | "env":{ 616 | "reward_scale":1, 617 | "obs_norm":false 618 | }, 619 | "meta_env":{ 620 | "obs_type": "with_goal_and_id" 621 | }, 622 | "replay_buffer":{ 623 | "size": 1e6 624 | }, 625 | "net":{ 626 | "hidden_shapes": [128,64,32,16], 627 | "append_hidden_shapes":[] 628 | }, 629 | "general_setting": { 630 | "discount" : 0.99, 631 | "pretrain_epochs" : 20, 632 | "num_epochs" : 7500, 633 | "epoch_frames" : 200, 634 | "max_episode_frames" : 200, 635 | 636 | "batch_size" : 1280, 637 | "min_pool" : 10000, 638 | 639 | "target_hard_update_period" : 1000, 640 | "use_soft_update" : true, 641 | "tau" : 0.005, 642 | "opt_times" : 200, 643 | 644 | "eval_episodes" : 3 645 | }, 646 | "sac":{ 647 | } 648 | }""" 649 | 650 | ml10 = metaworld.MT10() # Construct the benchmark, sampling tasks 651 | envs = [] 652 | for name, env_cls in ml10.train_classes.items(): 653 | env = env_cls() 654 | task = random.choice([task for task in ml10.train_tasks 655 | if task.env_name == name]) 656 | env.set_task(task) 657 | envs.append(env) 658 | 659 | !mkdir config 660 | with open('config/sac_ant.json','w') as f: 661 | f.write(text) 662 | 663 | args=parser() 664 | params = get_params(args.config) 665 | device = torch.device( 666 | "cuda:{}".format(args.device) if args.cuda else "cpu") 667 | if args.cuda: 668 | torch.cuda.manual_seed_all(args.seed) 669 | torch.backends.cudnn.benchmark = False 670 | torch.backends.cudnn.deterministic = True 671 | env.seed(args.seed) 672 | torch.manual_seed(args.seed) 673 | np.random.seed(args.seed) 674 | random.seed(args.seed) 675 | 676 | normalizer=Normalizer(env.observation_space.shape) 677 | buffer_param = params['replay_buffer'] 678 | experiment_name = os.path.split( 679 | os.path.splitext(args.config)[0])[-1] if args.id is None \ 680 | else args.id 681 | logger = Logger( 682 | experiment_name, params['env_name'], args.seed, params, args.log_dir) 683 | 684 | pfs=[] 685 | qf1s=[] 686 | vfs=[] 687 | agents=[] 688 | epochs=[1 for i in range(len(envs))] 689 | for index in range(len(envs)): 690 | print(index) 691 | env=SingleWrapper(envs[index]) 692 | params = get_params(args.config) 693 | params['general_setting']['logger'] = Logger( 694 | 'mt10', str(index), args.seed, params, './log/mt10_'+str(index)+'/') 695 | params['env_name']=str(index) 696 | params['general_setting']['env'] 
= env 697 | 698 | replay_buffer = BaseReplayBuffer( 699 | max_replay_buffer_size=int(buffer_param['size'])#, 700 | # time_limit_filter=buffer_param['time_limit_filter'] 701 | ) 702 | params['general_setting']['replay_buffer'] = replay_buffer 703 | params['general_setting']['device'] = device 704 | params['net']['base_type'] = networks.MLPBase 705 | params['net']['activation_func'] = torch.nn.ReLU 706 | 707 | pf = policies.GuassianContPolicy( 708 | input_shape=env.observation_space.shape[0], 709 | output_shape=2 * env.action_space.shape[0], 710 | **params['net'], 711 | **params['sac']) 712 | 713 | qf1 = networks.QNet( 714 | input_shape=env.observation_space.shape[0] + env.action_space.shape[0], 715 | output_shape=1, 716 | **params['net']) 717 | 718 | vf = networks.Net( 719 | input_shape=env.observation_space.shape, 720 | output_shape=1, 721 | **params['net'] 722 | ) 723 | pfs.append(pf) 724 | qf1s.append(qf1) 725 | vfs.append(vf) 726 | params['general_setting']['collector'] = BaseCollector( 727 | env=env, pf=pf, 728 | replay_buffer=replay_buffer, device=device, 729 | train_render=False 730 | ) 731 | params['general_setting']['save_dir'] = osp.join( 732 | './log/', "model10_"+str(index)) 733 | agent = SAC( 734 | pf=pf, 735 | qf=qf1,plr=3e-4,vlr=3e-4,qlr=3e-4, 736 | vf=vf, 737 | **params["sac"], 738 | **params["general_setting"] 739 | ) 740 | agents.append(agent) 741 | 742 | # If we do not save prior information in advance, we could delete priors and term2. 743 | priors = np.load('priors.npy') 744 | # differentiable ranking loss 745 | def pss(x,points): 746 | def pss0(x,i): 747 | return torch.tanh(200*torch.tensor(x-i))/2+0.5 748 | return len(points)-sum([pss0(x,i) for i in points]) 749 | 750 | def losst(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))): 751 | new_rloss=[i for i in rloss] 752 | new_rloss[t]=new_rloss[t]+1 753 | rlossRank=1+len(envs)-rankdata(new_rloss, method='min') 754 | points=B[t] 755 | sim=[sum(nn.CosineSimilarity()(pfs[t].state_dict()['last.weight'].view(-1,1),pfs[i].state_dict()['last.weight'].view(-1,1))) for i in range(len(envs))] 756 | sim[t]=sim[t]+100 757 | rlossRank2=1+len(envs)-rankdata(sim, method='min') 758 | term0=(1+mu*sum([torch.norm(torch.tensor(B[t][i]),p=1)for i in set(list(range(len(envs))))-set([t])]))*rloss[t] 759 | term1=sum([sum([sum([torch.norm(w[i][s]-sum([B[pi[j]][s]*w[i][pi[j]] for j in range(currindex-1)])-B[t][s]*w[i][t],p=2)**2]) for i in range(len(pfs[0].state_dict().keys()))]) for s in U]) 760 | term2=torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 761 | term3=torch.norm(torch.tensor(rlossRank)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 762 | term4=torch.norm(torch.tensor(rlossRank2)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 763 | terms_history[0].append(term0.detach().numpy()) 764 | terms_history[1].append(term1) 765 | terms_history[2].append(term2.detach().numpy()) 766 | terms_history[3].append(term3.detach().numpy()) 767 | terms_history[4].append(term4.detach().numpy()) 768 | if len(terms_history[0])<=1: 769 | return term0+term1+term2+term3+term4 770 | else: 771 | return 
(1/(np.array(terms_history[0]).std())**2)*term0+(1/(np.array(terms_history[1]).std())**2)*term1+(1/(np.array(terms_history[3]).std())**2)*term3+(1/(np.array(terms_history[4]).std())**2)*term4+np.log((np.array(terms_history[0]).std())*(np.array(terms_history[1]).std())*(np.array(terms_history[2]).std())*(np.array(terms_history[3]).std())*(np.array(terms_history[4]).std())) 772 | 773 | def lossb(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))): 774 | new_rloss=[i for i in rloss] 775 | new_rloss[t]=new_rloss[t]+1 776 | rlossRank=1+len(envs)-rankdata(new_rloss, method='min') 777 | points=B[t] 778 | sim=[sum(nn.CosineSimilarity()(pfs[t].state_dict()['last.weight'].view(-1,1),pfs[i].state_dict()['last.weight'].view(-1,1))) for i in range(len(envs))] 779 | sim[t]=sim[t]+100 780 | rlossRank2=1+len(envs)-rankdata(sim, method='min') 781 | term0=(1+mu*sum([torch.norm(B[t][i],p=1)for i in set(list(range(len(envs))))-set([t])]))*rloss[t] 782 | term1=sum([sum([sum([torch.norm(w[i][s]-sum([B[pi[j]][s]*w[i][pi[j]] for j in range(currindex-1)])-B[t][s]*w[i][t],p=2)**2]) for i in range(len(pfs[0].state_dict().keys()))]) for s in U]) 783 | term2=torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 784 | term3=torch.norm(torch.tensor(rlossRank)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 785 | term4=torch.norm(torch.tensor(rlossRank2)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 786 | 787 | if len(terms_history[0])<=1: 788 | return term0+term1+term2+term3+term4 789 | else: 790 | return (1/(np.array(terms_history[0]).std())**2)*term0+(1/(np.array(terms_history[1]).std())**2)*term1+(1/(np.array(terms_history[3]).std())**2)*term3+(1/(np.array(terms_history[4]).std())**2)*term4+np.log((np.array(terms_history[0]).std())*(np.array(terms_history[1]).std())*(np.array(terms_history[2]).std())*(np.array(terms_history[3]).std())*(np.array(terms_history[4]).std())) 791 | 792 | def lossw(currindex,t,rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01],U=[13],pi=list(range(len(envs)))): 793 | new_rloss=[i for i in rloss] 794 | new_rloss[t]=new_rloss[t]+1 795 | rlossRank=1+len(envs)-rankdata(new_rloss, method='min') 796 | points=B[t] 797 | term0=(1+mu*sum([torch.norm(torch.tensor(B[t][i]),p=1)for i in set(list(range(len(envs))))-set([t])]))*rloss[t] 798 | sim=[sum(nn.CosineSimilarity()(pfs[t].state_dict()['last.weight'].view(-1,1),pfs[i].state_dict()['last.weight'].view(-1,1))) for i in range(len(envs))] 799 | sim[t]=sim[t]+100 800 | rlossRank2=1+len(envs)-rankdata(sim, method='min') 801 | 802 | term1=sum([sum([torch.norm(w[i][t]-sum([B[pi[j]][t]*w[i][pi[j]] for j in range(currindex-1)]),p=2)**2]) for i in range(len(pfs[0].state_dict().keys()))]) 803 | term2=torch.norm(torch.tensor(priors[current])-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 804 | term3=torch.norm(torch.tensor(rlossRank)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 805 | term4=torch.norm(torch.tensor(rlossRank2)-torch.tensor([pss(torch.tensor(i-0.01),points) for i in points]))**2 806 | if len(terms_history[0])<=1: 807 | return term0+term1+term2+term3+term4 808 | else: 809 | return 
(1/(np.array(terms_history[0]).std())**2)*term0+(1/(np.array(terms_history[1]).std())**2)*term1+(1/(np.array(terms_history[3]).std())**2)*term3+(1/(np.array(terms_history[4]).std())**2)*term4+np.log((np.array(terms_history[0]).std())*(np.array(terms_history[1]).std())*(np.array(terms_history[2]).std())*(np.array(terms_history[3]).std())*(np.array(terms_history[4]).std())) 810 | 811 | #FrankWolfe 812 | OPTIMIZER_CLASSES = [FrankWolfe]# [PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe] 813 | radius=0.05 814 | 815 | def setup_problem(make_nonconvex=False): 816 | radius2 = radius 817 | loss_func=lossb 818 | constraint = LinfBall(radius2) 819 | return loss_func, constraint 820 | 821 | def optimize(loss_func, constraint, optimizer_class, iterations=100): 822 | for i in range(len(envs)): 823 | if i!=t: 824 | B[t][i] =torch.tensor(B[t][i],requires_grad=True) 825 | optimizer = [optimizer_class([B[t][i]], constraint) for i in set(list(range(len(envs))))-set([t])] 826 | iterates = [[B[t][i].data if i!=t else B[t][i] for i in range(len(envs))]] 827 | losses = [] 828 | # Use Madry's heuristic for step size 829 | step_size = { 830 | FrankWolfe.name: None, 831 | MomentumFrankWolfe.name: None, 832 | PGD.name: 2.5 * constraint.alpha / iterations * 2., 833 | PGDMadry.name: 2.5 * constraint.alpha / iterations 834 | } 835 | for _ in range(iterations): 836 | for i in range(len(envs)-1): 837 | optimizer[i].zero_grad() 838 | loss = loss_func(currindex,t,rloss,w,B,U=list(set(U)-set(list([t])))) 839 | loss.backward(retain_graph=True) 840 | for i in range(len(envs)-1): 841 | optimizer[i].step(step_size[optimizer[i].name]) 842 | for i in set(list(range(len(envs))))-set([t]): 843 | B[t][i].data.clamp_(0,100) 844 | losses.append(loss) 845 | iterates.append([B[t][i].data if i!=t else B[t][i] for i in range(len(envs))]) 846 | loss = loss_func(currindex,t,rloss,w,B,U=list(set(U)-set(list([t])))).detach() 847 | losses.append(loss) 848 | B[t]=[B[t][i].data if i!=t else B[t][i] for i in range(len(envs))] 849 | return losses, iterates 850 | 851 | multitask=False 852 | mu,lamb=0.01,[0.01,0.01,0.01]#para[0],[para[1],para[2],para[3]] 853 | terms_history=[[] for _ in range(5)] # independent history list for each of the five loss terms 854 | rloss=[0.0 for i in range(len(envs))] 855 | rewardsRec=[[] for i in range(len(envs))] 856 | rewardsRec_nor=[[0] for i in range(len(envs))] 857 | succeessRec=[[] for i in range(len(envs))] 858 | B=[list(i) for i in np.diag(np.ones(len(envs)))] 859 | TotalLossRec=[] 860 | TotalLossRec_withoutnormalization=[] 861 | indexVlist=[] 862 | 863 | for i_episode in range(10): 864 | for index, env in enumerate(envs): 865 | agents[index].train(epochs[index]) 866 | epochs[index]+=1 867 | rlosscopy=rloss.copy() 868 | TotalLossRec_withoutnormalization.append(rlosscopy) 869 | if TotalLossRec: 870 | TotalLossRec.append((rlosscopy-np.array(TotalLossRec_withoutnormalization).mean(axis=0))/(1e-5+np.array(TotalLossRec_withoutnormalization).std(axis=0))) 871 | else: 872 | TotalLossRec.append(rlosscopy) 873 | 874 | for i_episode in range(10,10000): 875 | rlosscopy=rloss.copy() 876 | low=np.array(rlosscopy).mean()-2*np.array(rlosscopy).std() 877 | high=np.array(rlosscopy).mean()+2*np.array(rlosscopy).std() # indexV: share of tasks whose loss falls outside [low, high], plus training-progress terms, each weighted by 1/3 878 | indexV=(sum(np.array(rlosscopy)<low)+sum(np.array(rlosscopy)>high))/len(envs)/3+np.exp(-i_episode/1000)/3+np.exp(-np.array(TotalLossRec[-1]).mean()/40)/3 879 | indexVlist.append([indexV,(sum(np.array(rlosscopy)<low)+sum(np.array(rlosscopy)>high))/len(envs)/3,np.exp(-i_episode/1000)/3,np.exp(-np.array(TotalLossRec[-1]).mean()/40)/3]) 880 | 881 | if np.random.random()