├── README.md ├── core_gogepo.py ├── gogepo.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | ### Goal-Conditioned Generators of Deep Policies 2 | 3 | [arxiv](https://arxiv.org/abs/2207.01570) 4 | 5 | Install required python packages: 6 | ```bash 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | To reproduce the main results, run GoGePo in different environments: 11 | ```bash 12 | python3 gogepo.py --env_name Swimmer-v3 --use_gpu 1 13 | python3 gogepo.py --env_name Hopper-v3 --use_gpu 1 14 | python3 gogepo.py --env_name InvertedPendulum-v2 --use_gpu 1 15 | python3 gogepo.py --env_name MountainCarContinuous-v0 --use_gpu 1 16 | ``` 17 | -------------------------------------------------------------------------------- /core_gogepo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions.normal import Normal 5 | from torch.distributions.categorical import Categorical 6 | from gym.spaces import Box, Discrete 7 | import numpy as np 8 | import random 9 | from collections import OrderedDict 10 | 11 | 12 | def mlp(sizes, activation, output_activation=nn.Identity):  # Linear layers, `activation` after each hidden layer, `output_activation` after the last 13 | layers = [] 14 | for j in range(len(sizes) - 1): 15 | act = activation if j < len(sizes) - 2 else output_activation 16 | layers += [nn.Linear(sizes[j], sizes[j + 1]), act()] 17 | 18 | return nn.Sequential(*layers) 19 | 20 | 21 | class VirtualModule:  # base for modules whose parameters are handled as one flat tensor (see split_parameters) 22 | def __init__(self): 23 | self._parameter_shapes = self.get_parameter_shapes() 24 | 25 | self._num_parameters = 0 26 | for shape in self.parameter_shapes.values(): 27 | numel = np.prod(shape) 28 | self._num_parameters += numel 29 | 30 | def get_parameter_shapes(self): 31 | # return an OrderedDict with the parameter names and their shape 32 | return NotImplementedError  # NOTE(review): returns the exception class instead of raising it; a subclass that does not override this fails in __init__ at .values() with an unrelated AttributeError 33 | 34 | def parameter_initialization(self, num_instances): 35 | factor = 1 /
((self.num_parameters / 100) ** 0.5)  # randn values are scaled by 1/sqrt(num_parameters / 100) 36 | initializations = [] 37 | for i in range(num_instances): 38 | p = [] 39 | for key, shape in self.parameter_shapes.items(): 40 | p.append(torch.randn(shape).view(-1) * factor) 41 | p = torch.cat(p, dim=0) 42 | initializations.append(p) 43 | return initializations 44 | 45 | def split_parameters(self, p):  # slice flat p (optionally batched on dim 0) into tensors shaped as in parameter_shapes, in dict order 46 | if len(p.shape) == 1: 47 | batch_size = [] 48 | else: 49 | batch_size = [p.shape[0]] 50 | pointer = 0 51 | parameters = [] 52 | for shape in self.parameter_shapes.values(): 53 | numel = np.prod(shape) 54 | x = p[..., pointer : pointer + numel].view(*(batch_size + list(shape))) 55 | parameters.append(x) 56 | pointer += numel 57 | return parameters 58 | 59 | @property 60 | def parameter_shapes(self): 61 | return self._parameter_shapes 62 | 63 | @property 64 | def num_parameters(self): 65 | return self._num_parameters 66 | 67 | 68 | class VirtualModuleWrapper(torch.nn.Module): 69 | # Allows treating a virtual module as a normal pytorch module (train with standard optimizers etc.)
70 | def __init__(self, virtual_module): 71 | super().__init__() 72 | self.virtual_module = virtual_module 73 | self.virtual_parameters = torch.nn.Parameter( 74 | self.virtual_module.parameter_initialization(1)[0]  # NOTE(review): for a VirtualMLP this goes through VirtualMLP.parameter_initialization, which reads an unset self.bias 75 | ) 76 | 77 | def forward(self, x): 78 | output = self.virtual_module.forward(x, self.virtual_parameters) 79 | return output 80 | 81 | 82 | def linear_multi_parameter(input, weight, bias=None): 83 | """ 84 | n: input batch dimension 85 | m: parameter batch dimension (optional) 86 | i: input feature dimension 87 | o: output feature dimension 88 | :param input: n x (m x) i 89 | :param weight: (m x) o x i 90 | :param bias: (m x) o 91 | :return: n x (m x) o 92 | """ 93 | 94 | if len(weight.shape) == 2: 95 | # no parameter batch dimension 96 | x = torch.einsum("ni,oi->no", input, weight) 97 | elif len(input.shape) == 3: 98 | # parameter batch dimension for input and weights 99 | x = torch.einsum("nmi,moi->nmo", input, weight) 100 | else: 101 | # parameter batch dimension for weights only; the input is shared across parameter sets 102 | x = torch.einsum("ni,moi->nmo", input, weight) 103 | if bias is not None: 104 | x = x + bias.unsqueeze(0) 105 | return x 106 | 107 | 108 | class VirtualMLP(VirtualModule): 109 | def __init__(self, layer_sizes, nonlinearity="tanh", output_activation="linear"): 110 | self.layer_sizes = layer_sizes 111 | 112 | if nonlinearity == "tanh": 113 | self.nonlinearity = torch.tanh 114 | elif nonlinearity == "sigmoid": 115 | self.nonlinearity = torch.sigmoid 116 | else: 117 | self.nonlinearity = torch.relu  # NOTE(review): any unrecognized nonlinearity name silently falls back to ReLU 118 | 119 | if output_activation == "linear":  # NOTE(review): an unrecognized output_activation leaves self.output_activation unset, so forward raises AttributeError 120 | self.output_activation = None 121 | elif output_activation == "sigmoid": 122 | self.output_activation = torch.sigmoid 123 | elif output_activation == "tanh": 124 | self.output_activation = torch.tanh 125 | elif output_activation == "softmax": 126 | self.output_activation = lambda x: torch.softmax(x, dim=-1) 127 | 128 | super(VirtualMLP, self).__init__()  # called last: the base __init__ calls get_parameter_shapes, which reads self.layer_sizes 129 | 130 | def get_parameter_shapes(self): 131 | parameter_shapes =
OrderedDict()  # insertion order w1, wb1, w2, wb2, ... defines the flat parameter layout 132 | for i in range(1, len(self.layer_sizes)): 133 | parameter_shapes["w" + str(i)] = ( 134 | self.layer_sizes[i], 135 | self.layer_sizes[i - 1], 136 | ) 137 | parameter_shapes["wb" + str(i)] = (self.layer_sizes[i],) 138 | 139 | return parameter_shapes 140 | 141 | def forward(self, input, parameters, callback_func=None):  # callback_func(x, l) is called on each hidden activation 142 | # input: input_batch x (parameter_batch x) input_size 143 | # parameters: (parameter_batch x) num_params 144 | # return: input_batch x (parameter_batch x) output_size 145 | p = self.split_parameters(parameters) 146 | num_layers = len(self.layer_sizes) - 1 147 | x = input 148 | for l in range(0, num_layers): 149 | w = p[l * 2] 150 | a = linear_multi_parameter(x, w, bias=p[l * 2 + 1]) 151 | if l < num_layers - 1: 152 | x = self.nonlinearity(a) 153 | if callback_func is not None: 154 | callback_func(x, l) 155 | else: 156 | x = a if self.output_activation is None else self.output_activation(a) 157 | return x 158 | 159 | def parameter_initialization(self, num_instances, bias_var=0.0):  # xavier-normal weights; biases zero, or normal with variance bias_var 160 | initializations = [] 161 | for i in range(num_instances): 162 | p = [] 163 | for i in range(1, len(self.layer_sizes)):  # NOTE(review): reuses the outer loop name i (harmless, the outer i is unused) 164 | w = torch.empty(self.parameter_shapes["w" + str(i)]) 165 | torch.nn.init.xavier_normal_(w) 166 | p.append(w.view(-1)) 167 | if self.bias:  # NOTE(review): self.bias is never assigned in VirtualMLP or VirtualMLPPolicy, so this raises AttributeError 168 | for i in range(1, len(self.layer_sizes)):  # NOTE(review): biases are appended after all weights, but split_parameters expects w1, wb1, w2, wb2, ... so the layouts disagree for more than one layer 169 | b = torch.empty(self.parameter_shapes["wb" + str(i)]) 170 | if bias_var == 0: 171 | torch.nn.init.zeros_(b) 172 | else: 173 | torch.nn.init.normal_(b, std=bias_var**0.5) 174 | p.append(b.view(-1)) 175 | p = torch.cat(p, dim=0) 176 | initializations.append(p) 177 | return initializations 178 | 179 | 180 | class VirtualMLPPolicy(VirtualMLP): 181 | def __init__(self, layer_sizes, bias=True, act_lim=1):  # NOTE(review): bias is accepted but never stored or used 182 | super().__init__( 183 | layer_sizes=layer_sizes, nonlinearity="tanh", output_activation="tanh" 184 | ) 185 | self.act_lim = act_lim 186 | 187 | def forward(self, input, parameters, callback_func=None): 188 | x =
super().forward(input, parameters, callback_func) 189 | x = x * self.act_lim 190 | return x 191 | 192 | 193 | def get_hypernetwork_mlp_generator(  # build a HyperNetworkGenerator that emits the flat parameters of an MLP with the given layer_sizes 194 | layer_sizes, 195 | hidden_sizes, 196 | embedding_dim, 197 | features_per_embedding=32, 198 | scale_layer_out=False, 199 | scale_parameter=1, 200 | ): 201 | input_hn = HyperNetwork( 202 | hidden_sizes=hidden_sizes, 203 | z_dim_w=embedding_dim + 1,  # +1 presumably for a 1-dim command concatenated in HyperNetwork.forward; TODO confirm when conditioning is used 204 | z_dim_b=embedding_dim + 1, 205 | out_size_w=[ 206 | layer_sizes[1] if len(layer_sizes) == 2 else features_per_embedding, 207 | layer_sizes[0], 208 | ], 209 | out_size_b=layer_sizes[1] if len(layer_sizes) == 2 else features_per_embedding, 210 | ) 211 | 212 | if len(layer_sizes) > 2: 213 | output_hn = HyperNetwork( 214 | hidden_sizes=hidden_sizes, 215 | z_dim_w=embedding_dim + 1, 216 | z_dim_b=embedding_dim + 1, 217 | out_size_w=[layer_sizes[-1], features_per_embedding], 218 | out_size_b=layer_sizes[-1], 219 | ) 220 | else: 221 | output_hn = None 222 | 223 | if len(layer_sizes) > 3: 224 | hidden_hn = HyperNetwork( 225 | hidden_sizes=hidden_sizes, 226 | z_dim_w=embedding_dim + 1, 227 | z_dim_b=embedding_dim + 1, 228 | out_size_w=[features_per_embedding, features_per_embedding], 229 | out_size_b=features_per_embedding, 230 | ) 231 | else: 232 | hidden_hn = None 233 | 234 | in_tiling = [ 235 | 1, 236 | 1 if len(layer_sizes) == 2 else layer_sizes[1] // features_per_embedding,  # NOTE(review): floor division drops any remainder; hidden sizes are assumed to be multiples of features_per_embedding 237 | ] 238 | out_tiling = ( 239 | [layer_sizes[-2] // features_per_embedding, 1] 240 | if len(layer_sizes) >= 2  # NOTE(review): also true for len == 2, where output_hn is None, so the generator gets an output layer without a hypernetwork; probably meant > 2, confirm 241 | else None 242 | ) 243 | if len(layer_sizes) > 3: 244 | hidden_tiling = [] 245 | for i in range(1, len(layer_sizes) - 2): 246 | ht = [ 247 | layer_sizes[i] // features_per_embedding, 248 | layer_sizes[i + 1] // features_per_embedding, 249 | ] 250 | hidden_tiling.append(ht) 251 | else: 252 | hidden_tiling = None 253 | 254 | fc_generator = HyperNetworkGenerator( 255 | input_fc_hn=input_hn, 256 | hidden_fc_hn=hidden_hn, 257 | output_fc_hn=output_hn, 258 | in_tiling=in_tiling, 259 |
hidden_tiling=hidden_tiling, 260 | out_tiling=out_tiling, 261 | embedding_dim=embedding_dim, 262 | layer_sizes=layer_sizes, 263 | scale_layer_out=scale_layer_out, 264 | scale_parameter=scale_parameter, 265 | ) 266 | 267 | return fc_generator 268 | 269 | 270 | class HyperNetwork(nn.Module):  # two MLPs mapping concat(z, command) to one weight tile and one bias tile 271 | def __init__( 272 | self, hidden_sizes, z_dim_w=65, z_dim_b=4, out_size_w=[8, 8], out_size_b=8 273 | ): 274 | super(HyperNetwork, self).__init__() 275 | self.z_dim_w = z_dim_w 276 | self.z_dim_b = z_dim_b 277 | 278 | self.out_size_w = out_size_w 279 | self.out_size_b = out_size_b 280 | self.total_el_w = self.out_size_w[0] * self.out_size_w[1] 281 | 282 | sizes_w = [self.z_dim_w] + list(hidden_sizes) + [self.total_el_w] 283 | self.net_w = mlp(sizes_w, activation=nn.ReLU) 284 | sizes_b = [self.z_dim_b] + list(hidden_sizes) + [self.out_size_b] 285 | self.net_b = mlp(sizes_b, activation=nn.ReLU) 286 | 287 | def forward(self, z, command): 288 | # z: batch_size x z_dim 289 | # command: batch_size x 1 290 | kernel_w = self.net_w(torch.cat((z, command), dim=1)) 291 | kernel_w = kernel_w.view(-1, self.out_size_w[0], self.out_size_w[1]) 292 | 293 | kernel_b = self.net_b(torch.cat((z, command), dim=1)) 294 | kernel_b = kernel_b.view(-1, self.out_size_b) 295 | 296 | return kernel_w, kernel_b 297 | 298 | 299 | class HyperNetworkGenerator(torch.nn.Module):  # assembles each generated layer from a grid of tiles produced by per-layer HyperNetworks 300 | def __init__( 301 | self, 302 | input_fc_hn: HyperNetwork, 303 | hidden_fc_hn: HyperNetwork = None, 304 | output_fc_hn: HyperNetwork = None, 305 | in_tiling=[1, 1], 306 | hidden_tiling=None, 307 | out_tiling=None, 308 | embedding_dim: int = 64, 309 | layer_sizes=None, 310 | scale_layer_out=False, 311 | scale_parameter=1, 312 | ): 313 | super().__init__() 314 | # layer generators 315 | self.input_hn = input_fc_hn 316 | self.hidden_hn = hidden_fc_hn 317 | self.output_hn = output_fc_hn 318 | # tilings 319 | self.in_tiling = in_tiling 320 | self.hidden_tiling = hidden_tiling 321 | self.out_tiling = out_tiling 322 | if layer_sizes is
not None: 323 | self.layer_sizes = layer_sizes 324 | else: 325 | raise ValueError  # layer_sizes is required; forward reads it in the scale_layer_out branch 326 | self.scale_layer_out = scale_layer_out 327 | self.scale_parameter = scale_parameter 328 | self.num_layers = 1 329 | if self.hidden_tiling is not None: 330 | self.num_layers += len(self.hidden_tiling) 331 | if self.out_tiling is not None: 332 | self.num_layers += 1 333 | 334 | # embeddings 335 | self.in_embeddings = torch.nn.Parameter( 336 | torch.randn(self.in_tiling + [embedding_dim]) 337 | ) 338 | self.out_embeddings = torch.nn.Parameter( 339 | torch.randn(self.out_tiling + [embedding_dim])  # NOTE(review): raises TypeError when out_tiling is None (its default) 340 | ) 341 | if self.num_layers >= 3: 342 | self.hidden_embeddings = torch.nn.ParameterList( 343 | [ 344 | torch.nn.Parameter(torch.randn(ft + [embedding_dim])) 345 | for ft in self.hidden_tiling 346 | ] 347 | ) 348 | else: 349 | self.hidden_embeddings = None 350 | 351 | def forward( 352 | self, command: torch.FloatTensor, conditioning: torch.FloatTensor = None 353 | ): 354 | """ 355 | :param command: batch_size x 1 356 | :param conditioning: batch_size x conditioning_size (optional, concatenated to command) 357 | :return: batch_size x total_num_parameters flat tensor (per layer: weight, then bias) 358 | """ 359 | batch_size = command.shape[0] 360 | 361 | if conditioning is not None: 362 | command = torch.cat([command, conditioning], dim=1) 363 | 364 | generated_parameters = [] 365 | 366 | # fully connected 367 | for i in range(self.num_layers): 368 | if i == 0: 369 | hn = self.input_hn 370 | tiling = self.in_tiling 371 | embeddings = self.in_embeddings 372 | elif i == self.num_layers - 1: 373 | hn = self.output_hn 374 | tiling = self.out_tiling 375 | embeddings = self.out_embeddings 376 | else: 377 | hn = self.hidden_hn 378 | tiling = self.hidden_tiling[i - 1] 379 | embeddings = self.hidden_embeddings[i - 1] 380 | # repeat embeddings across batch 381 | embeddings = embeddings[None].repeat( 382 | batch_size, 1, 1, 1 383 | ) # batch_size x tiles_in x tiles_out x z_dim 384 | # repeat command across tiles 385 | r_command = command[:, None, None, :].repeat(1, tiling[0], tiling[1], 1) 386 | embeddings =
embeddings.view(-1, embeddings.shape[-1]) 387 | r_command = r_command.view(-1, r_command.shape[-1]) 388 | w, b = hn(embeddings, r_command) 389 | if self.scale_layer_out:  # divide by sqrt(layer_sizes[i]), the input size of layer i 390 | w = ( 391 | w 392 | * self.scale_parameter 393 | / torch.sqrt(torch.tensor([self.layer_sizes[i]]).float()).to( 394 | w.device 395 | ) 396 | ) 397 | b = ( 398 | b 399 | * self.scale_parameter 400 | / torch.sqrt(torch.tensor([self.layer_sizes[i]]).float()).to( 401 | b.device 402 | ) 403 | ) 404 | 405 | w = w.view( 406 | batch_size, tiling[0], tiling[1], hn.out_size_w[0], hn.out_size_w[1] 407 | ).permute(0, 2, 3, 1, 4)  # -> batch x tiles_out x rows x tiles_in x cols 408 | w = w.reshape( 409 | batch_size, tiling[1] * hn.out_size_w[0], tiling[0] * hn.out_size_w[1] 410 | ) 411 | b = b.reshape(batch_size, tiling[0], tiling[1], hn.out_size_w[0]).mean(  # NOTE(review): sized with hn.out_size_w[0] rather than hn.out_size_b (equal for generators from get_hypernetwork_mlp_generator); bias tiles are averaged over the input-tile axis 412 | dim=1 413 | ) 414 | b = b.view(batch_size, tiling[1] * hn.out_size_w[0]) 415 | generated_parameters.extend([w, b]) 416 | 417 | flat_parameters = [p.view(p.shape[0], -1) for p in generated_parameters] 418 | flat_parameters = torch.cat(flat_parameters, dim=1) 419 | 420 | generated_parameters = flat_parameters 421 | return generated_parameters 422 | 423 | 424 | class PSSVF_linear(nn.Module):  # value network applied directly to flat policy parameters 425 | def __init__(self, parameter_space_dim, hidden_sizes, activation): 426 | super().__init__() 427 | self.v_net = mlp([parameter_space_dim] + list(hidden_sizes) + [1], activation) 428 | 429 | def forward(self, parameters): 430 | return torch.squeeze(self.v_net(parameters), -1) 431 | 432 | 433 | class Command(nn.Module):  # a single learnable scalar, initialized to 0 434 | def __init__(self): 435 | super().__init__() 436 | self.command = torch.nn.Parameter((torch.as_tensor(0.0).float())) 437 | 438 | 439 | class PSSVF(nn.Module):  # value network over the actions a policy takes at learnable probing states 440 | def __init__( 441 | self, obs_dim, num_probing_states, parameter_space_dim, hidden_sizes, activation 442 | ): 443 | super().__init__() 444 | 445 | self.probing_states = nn.ParameterList( 446 | [nn.Parameter(torch.rand([obs_dim])) for _ in range(num_probing_states)] 447 | ) 448 | self.v_net = mlp([parameter_space_dim] +
list(hidden_sizes) + [1], activation) 449 | 450 | def forward(self, parameters, use_virtual_module=True, virtual_module=None):  # NOTE(review): use_virtual_module is unused, and virtual_module is required despite defaulting to None 451 | prob_sates = torch.stack( 452 | [ 453 | torch.nn.utils.parameters_to_vector(state) 454 | for state in self.probing_states 455 | ] 456 | ) 457 | actions = ( 458 | virtual_module.forward(prob_sates, parameters) 459 | .transpose(0, 1)  # assumes output is probing_states x parameter_batch x act_dim (batched parameters); TODO confirm 460 | .reshape(parameters.shape[0], -1) 461 | ) 462 | return torch.squeeze(self.v_net(actions), -1) 463 | 464 | 465 | class MLPActorCritic(nn.Module):  # builds the actor (MLPActor) and, when critic is set, a PSSVF value network 466 | def __init__( 467 | self, 468 | observation_space, 469 | action_space, 470 | n_probing_states, 471 | hidden_sizes_actor, 472 | activation, 473 | hidden_sizes_critic, 474 | device, 475 | critic, 476 | deterministic_actor, 477 | ): 478 | super().__init__() 479 | 480 | self.device = device 481 | self.deterministic_actor = deterministic_actor 482 | obs_dim = observation_space.shape[0] 483 | if isinstance(action_space, Box): 484 | self.act_dim = action_space.shape[0] 485 | act_limit = action_space.high[0] 486 | self.act_limit = act_limit 487 | elif isinstance(action_space, Discrete):  # NOTE(review): any other space type leaves self.act_dim unset 488 | self.act_dim = action_space.n 489 | 490 | self.pi = MLPActor( 491 | observation_space, 492 | action_space, 493 | hidden_sizes_actor, 494 | activation, 495 | device, 496 | deterministic_actor, 497 | ).to(device=device) 498 | 499 | named_params = self.pi.named_parameters() 500 | self.names, params = zip(*named_params) 501 | self.shapes = [param.shape for param in params] 502 | self.flat_shapes = [torch.flatten(param).shape for param in params] 503 | 504 | if critic: 505 | if isinstance(action_space, Box) and not deterministic_actor: 506 | # mean and sd of gaussian 507 | self.parameters_dim = n_probing_states * self.act_dim * 2 508 | else: 509 | self.parameters_dim = n_probing_states * self.act_dim 510 | 511 | self.v = PSSVF( 512 | obs_dim, 513 | n_probing_states, 514 | self.parameters_dim, 515 | hidden_sizes_critic, 516 | nn.ReLU, 517 | ).to(device=device) 518 | 519 | def
act(self, obs, params, virtual_module=None):  # run virtual_module on one observation without gradients; returns a numpy array on CPU 520 | with torch.no_grad(): 521 | a = virtual_module.forward(obs.unsqueeze(0), params) 522 | return a.to(device="cpu").numpy() 523 | 524 | 525 | class DeterministicActorFunc(nn.Module):  # evaluates an MLP from a flat parameter vector using the names and shapes of a reference module 526 | def __init__(self, act_limit, names, shapes, flat_shapes): 527 | super().__init__() 528 | self.act_limit = act_limit 529 | self.names = names 530 | self.shapes = shapes 531 | self.flat_shapes = flat_shapes 532 | 533 | def forward(self, x, params): 534 | temp = 0 535 | for idx, name in enumerate(self.names): 536 | if "weight" in name: 537 | weight = ( 538 | params[temp : temp + self.flat_shapes[idx][0]] 539 | .unsqueeze(0) 540 | .reshape(self.shapes[idx]) 541 | ) # .reshape([0:self.shapes[idx][0], 0:self.shapes[idx][1]]) 542 | temp += self.flat_shapes[idx][0] 543 | elif "bias" in name: 544 | bias = params[temp : temp + self.flat_shapes[idx][0]].reshape( 545 | self.shapes[idx] 546 | ) 547 | temp += self.flat_shapes[idx][0] 548 | else: 549 | raise ValueError 550 | if "bias" in name: 551 | x = F.linear(x, weight, bias) 552 | x = F.tanh(x)  # tanh after every layer, including the last 553 | 554 | x = x * self.act_limit 555 | return x 556 | 557 | def get_probing_action(self, obs): 558 | return torch.tanh(self.pi_net(obs)) * self.act_limit  # NOTE(review): this class never defines self.pi_net, so this raises AttributeError 559 | 560 | 561 | # class Generator_lin(nn.Module): 562 | # def __init__(self, parameter_space_dim, hidden_sizes, activation, scale_parameter, device): 563 | # super().__init__() 564 | # self.scale_parameter = scale_parameter 565 | # self.device = device 566 | # self.parameter_dim = parameter_space_dim 567 | # self.v_net = mlp([1] + list(hidden_sizes) + [parameter_space_dim], activation) 568 | # 569 | # def forward(self, command, noise=None, evaluator=None, return_all=False, param=None): 570 | # parameters = torch.squeeze(self.v_net(command), -1) 571 | # parameters = parameters / self.scale_parameter 572 | # 573 | # if evaluator is None: 574 | # 575 | # if noise is not None: 576 | # pi = self._distribution(parameters,
std=torch.as_tensor(noise).float().to(self.device)) 577 | # 578 | # dist = Normal(torch.zeros(self.parameter_dim), scale=1) 579 | # delta = dist.sample().to(device=self.device, non_blocking=True).detach() 580 | # parameters = parameters + noise * delta 581 | # 582 | # if param is not None: 583 | # logp_a = self._log_prob_from_distribution(pi, param) 584 | # else: 585 | # #print(pi, parameters) 586 | # logp_a = self._log_prob_from_distribution(pi, parameters) 587 | # 588 | # return parameters, logp_a 589 | # else: 590 | # return parameters 591 | # 592 | # else: 593 | # value = evaluator(parameters) 594 | # if return_all == False: 595 | # return value 596 | # else: 597 | # return value, parameters 598 | # 599 | # def _distribution(self, parameters, std): 600 | # mu = parameters 601 | # return Normal(mu, std) 602 | # 603 | # def _log_prob_from_distribution(self, pi, act): 604 | # return pi.log_prob(act).sum(axis=-1) 605 | 606 | 607 | class Generator(nn.Module):  # maps a command tensor (named tot_reward in forward) to flat policy parameters through a hypernetwork generator 608 | def __init__( 609 | self, 610 | pi, 611 | parameter_dim, 612 | hidden_sizes, 613 | use_hyper,  # NOTE(review): use_hyper, use_parallel, hid_size_b and out_size_b do not affect the encoder built below 614 | use_parallel, 615 | hid_size_w, 616 | hid_size_b, 617 | out_size_w, 618 | out_size_b, 619 | device, 620 | policy_neurons=None, 621 | scale_layer_out=False, 622 | scale_parameter=None, 623 | ): 624 | super().__init__() 625 | 626 | self.device = device 627 | self.parameter_dim = parameter_dim 628 | self.use_hyper = use_hyper 629 | self.scale_parameter = scale_parameter 630 | self.scale_layer_out = scale_layer_out 631 | 632 | print( 633 | "scale layer by layer:", 634 | self.scale_layer_out, 635 | "otherwise scale everything with scale", 636 | self.scale_parameter, 637 | ) 638 | 639 | if policy_neurons is None: 640 | pi_shapes = [p.shape for p in pi.parameters()] 641 | policy_neurons = [  # NOTE(review): assumes pi has exactly three Linear layers (weights at parameter indices 0, 2, 4) 642 | pi_shapes[0][1], 643 | pi_shapes[2][1], 644 | pi_shapes[4][1], 645 | pi_shapes[4][0], 646 | ] 647 | 648 | self.encoder = get_hypernetwork_mlp_generator( 649 | policy_neurons, 650 | hidden_sizes, 651 |
embedding_dim=hid_size_w, 652 | features_per_embedding=out_size_w[0], 653 | scale_layer_out=scale_layer_out, 654 | scale_parameter=scale_parameter, 655 | ).to(device) 656 | 657 | def forward(  # returns parameters; with noise, a (parameters, None) tuple; with evaluator, its value (plus parameters if return_all) 658 | self, 659 | tot_reward, 660 | noise=None, 661 | use_virtual_module=True, 662 | evaluator=None, 663 | virtual_module=None, 664 | return_all=False, 665 | param=None, 666 | ): 667 | parameters = self.encoder(tot_reward) 668 | 669 | # print("shape", parameters.shape) 670 | if evaluator is None: 671 | if noise is not None: 672 | dist = Normal(torch.zeros(self.parameter_dim), scale=1)  # NOTE(review): one noise vector is sampled per call and broadcast to every row of the batch 673 | delta = dist.sample().to(device=self.device, non_blocking=True).detach() 674 | parameters = parameters + noise * delta 675 | logp_a = None  # the log-probability is not computed here 676 | 677 | return parameters, logp_a 678 | else: 679 | return parameters 680 | 681 | else: 682 | value = evaluator( 683 | parameters, 684 | use_virtual_module=use_virtual_module, 685 | virtual_module=virtual_module, 686 | ) 687 | if return_all == False: 688 | return value 689 | else: 690 | return value, parameters 691 | 692 | def _distribution(self, parameters, std): 693 | mu = parameters 694 | return Normal(mu, std) 695 | 696 | def _log_prob_from_distribution(self, pi, act): 697 | return pi.log_prob(act).sum(axis=-1) 698 | 699 | 700 | class MLPStochasticCategoricalActor(nn.Module): 701 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation): 702 | super().__init__() 703 | 704 | pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim] 705 | self.pi_net = mlp(pi_sizes, activation) 706 | 707 | def _distribution(self, obs): 708 | logits = self.pi_net(obs) 709 | return Categorical(logits=logits) 710 | 711 | def _log_prob_from_distribution(self, pi, act): 712 | return pi.log_prob(act) 713 | 714 | def forward(self, obs): 715 | pi = self._distribution(obs) 716 | return pi 717 | 718 | def get_probing_action(self, obs): 719 | pi = self._distribution(obs) 720 | return pi.probs 721 | 722 | 723 | class MLPDeterministicCategoricalActor(nn.Module): 724 | def
__init__(self, obs_dim, act_dim, hidden_sizes, activation): 725 | super().__init__() 726 | pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim] 727 | self.pi_net = mlp(pi_sizes, activation) 728 | self.softmax = nn.Softmax()  # NOTE(review): no dim given, so PyTorch chooses one implicitly (with a warning); dim=-1 is probably intended 729 | 730 | def forward(self, obs): 731 | logits = self.pi_net(obs) 732 | out = self.softmax(logits) 733 | a = out.argmax()  # NOTE(review): no dim, so a batch of observations yields a single index into the flattened tensor 734 | return a 735 | 736 | def get_probing_action(self, obs): 737 | logits = self.pi_net(obs) 738 | out = self.softmax(logits) 739 | a = out.argmax() 740 | return a 741 | 742 | 743 | class MLPGaussianActor(nn.Module): 744 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit): 745 | super().__init__() 746 | pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim] 747 | self.pi_net = mlp(pi_sizes, activation) 748 | 749 | log_std = -0.5 * np.ones(act_dim, dtype=np.float32) 750 | self.log_std = torch.nn.Parameter(torch.as_tensor(log_std)) 751 | self.act_limit = act_limit 752 | 753 | def _distribution(self, obs): 754 | mu = self.pi_net(obs) 755 | std = torch.exp(self.log_std) 756 | return Normal(mu, std) 757 | 758 | def _log_prob_from_distribution(self, pi, act): 759 | return pi.log_prob(act).sum(axis=-1) 760 | 761 | def forward(self, obs): 762 | 763 | pi = self._distribution(obs) 764 | return pi 765 | 766 | def get_probing_action(self, obs): 767 | pi = self._distribution(obs) 768 | return torch.cat((pi.mean, pi.scale))  # mean and std concatenated along dim 0 769 | 770 | 771 | class DeterministicActor(nn.Module): 772 | def __init__(self, obs_dim, act_dim, hidden_sizes, activation, act_limit): 773 | super().__init__() 774 | pi_sizes = [obs_dim] + list(hidden_sizes) + [act_dim] 775 | self.pi_net = mlp(pi_sizes, activation, nn.Tanh)  # output already passes through Tanh 776 | self.act_limit = act_limit 777 | 778 | def forward(self, obs): 779 | return self.pi_net(obs) 780 | 781 | def get_probing_action(self, obs): 782 | return torch.tanh(self.pi_net(obs)) * self.act_limit  # NOTE(review): pi_net already ends in Tanh, so tanh is applied twice 783 | 784 | 785 | class MLPActor(nn.Module): 786 | def __init__( 787 | self, 788 | observation_space, 789 |
action_space, 790 | hidden_sizes_actor, 791 | activation, 792 | device, 793 | deterministic_actor, 794 | ): 795 | super().__init__() 796 | 797 | self.act_limit = None 798 | self.deterministic_actor = deterministic_actor 799 | obs_dim = observation_space.shape[0] 800 | if isinstance(action_space, Box): 801 | act_dim = action_space.shape[0] 802 | act_limit = action_space.high[0] 803 | self.act_limit = act_limit 804 | elif isinstance(action_space, Discrete): 805 | act_dim = action_space.n 806 | 807 | if isinstance(action_space, Box): 808 | if deterministic_actor: 809 | self.pi = DeterministicActor( 810 | obs_dim, act_dim, hidden_sizes_actor, activation, act_limit 811 | ).to(device=device) 812 | else: 813 | self.pi = MLPGaussianActor( 814 | obs_dim, act_dim, hidden_sizes_actor, activation, act_limit 815 | ).to(device=device) 816 | 817 | elif isinstance(action_space, Discrete): 818 | if deterministic_actor: 819 | self.pi = MLPDeterministicCategoricalActor( 820 | obs_dim, action_space.n, hidden_sizes_actor, activation 821 | ).to(device=device) 822 | else: 823 | self.pi = MLPStochasticCategoricalActor( 824 | obs_dim, action_space.n, hidden_sizes_actor, activation 825 | ).to(device=device) 826 | 827 | def act(self, obs): 828 | with torch.no_grad(): 829 | if self.deterministic_actor: 830 | a = self.pi(obs) 831 | a = (torch.tanh(a) * self.act_limit).to(device="cpu").numpy()  # NOTE(review): DeterministicActor output already ends in Tanh (double tanh); for Discrete spaces act_limit is None and this path fails 832 | return a 833 | else: 834 | pi = self.pi(obs) 835 | a = pi.sample() 836 | if isinstance(self.pi, MLPGaussianActor): 837 | a = self.act_limit * torch.tanh(a).to(device="cpu").numpy() 838 | return a 839 | 840 | 841 | class Statistics(object):  # episode and reward counters plus a running observation mean and std 842 | def __init__(self, obs_dim): 843 | super().__init__() 844 | 845 | self.total_ts = 0 846 | self.episode = 0 847 | self.len_episode = 0 848 | self.rew_shaped_eval = 0 849 | self.rew_eval = 0 850 | self.rewards = [] 851 | self.last_rewards = [] 852 | self.position = 0 853 | self.n = 0 854 | self.mean = torch.zeros(obs_dim) 855 | self.mean_diff =
torch.zeros(obs_dim) 856 | self.std = torch.zeros(obs_dim) 857 | self.command = 0 858 | self.last_rewards_env = [] 859 | self.rewards_env = [] 860 | self.position_env = 0 861 | self.max_rew = -np.inf 862 | self.min_rew = np.inf 863 | self.sim_time = 0 864 | self.up_policy_time = 0 865 | self.up_v_time = 0 866 | self.total_time = 0 867 | self.gen_time = 0 868 | self.max_pred = -np.inf 869 | 870 | def push_obs(self, obs):  # running (Welford-style) mean and variance update 871 | self.n += 1.0 872 | last_mean = self.mean  # NOTE(review): aliases the tensor; the in-place += below also changes last_mean, so the variance term uses the new mean twice; Welford needs a copy such as self.mean.clone() 873 | self.mean += (obs - self.mean) / self.n 874 | self.mean_diff += (obs - last_mean) * (obs - self.mean) 875 | var = self.mean_diff / (self.n - 1) if self.n > 1 else np.square(self.mean)  # NOTE(review): for n == 1 the squared mean stands in for the variance 876 | self.std = np.sqrt(var) 877 | return 878 | 879 | def push_rew(self, rew):  # keep the 20 most recent rewards in a ring buffer, plus the full history 880 | if len(self.last_rewards) < 20: 881 | self.last_rewards.append(rew) 882 | else: 883 | self.last_rewards[self.position] = rew 884 | self.position = (self.position + 1) % 20 885 | self.rewards.append(rew) 886 | 887 | def push_rew_env(self, rew): 888 | if len(self.last_rewards_env) < 20: 889 | self.last_rewards_env.append(rew) 890 | else: 891 | self.last_rewards_env[self.position_env] = rew 892 | self.position_env = (self.position_env + 1) % 20 893 | self.rewards_env.append(rew) 894 | 895 | def normalize(self, obs): 896 | return (obs - self.mean) / (self.std + 1e-8) 897 | 898 | 899 | class Buffer(object): 900 | def __init__(self, size_buffer, scale=1.0): 901 | self.history = [] 902 | self.size_buffer = size_buffer 903 | self.scale = scale 904 | 905 | def sample_replay(self, batch_size, weighted_sampling=False):  # weighted: the last history entry gets weight 1, earlier ones (1/rank) ** scale, sampled with replacement; otherwise without replacement 906 | 907 | if weighted_sampling: 908 | self.weights = list( 909 | np.reciprocal(np.arange(1, len(self.history) + 1, dtype=float)) 910 | ) 911 | self.weights.reverse() 912 | self.weights = np.array(self.weights) ** self.scale 913 | self.weights = list(self.weights) 914 | sampled_hist = random.choices( 915 | self.history, 916 | weights=self.weights, 917 | k=min(int(batch_size), len(self.history)), 918 | ) 919 | else: 920 |
sampled_hist = random.sample( 921 | self.history, min(int(batch_size), len(self.history)) 922 | ) 923 | if len(self.history) > self.size_buffer:  # NOTE(review): capacity is enforced only here, one item per call, so history can exceed size_buffer between samples 924 | self.history.pop(0) 925 | return sampled_hist 926 | 927 | 928 | class BufferTD(object): 929 | def __init__(self, capacity): 930 | self.history = [] 931 | self.capacity = capacity 932 | self.position = 0 933 | 934 | def push(self, transition):  # ring buffer: overwrite the oldest slot once capacity is reached 935 | if len(self.history) < self.capacity: 936 | self.history.append(transition) 937 | else: 938 | self.history[self.position] = transition 939 | self.position = (self.position + 1) % self.capacity 940 | 941 | def sample_replay_td(self, batch_size): 942 | 943 | sampled_trans = random.choices(self.history, k=int(batch_size)) 944 | return sampled_trans 945 | 946 | 947 | def grad_norm(parameters): 948 | # Compute the global L2 norm of the gradients (parameters without .grad are skipped) 949 | if isinstance(parameters, torch.Tensor): 950 | parameters = [parameters] 951 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 952 | norm_type = float(2) 953 | total_norm = 0 954 | for p in parameters: 955 | param_norm = p.grad.data.norm(norm_type) 956 | total_norm += param_norm.item() ** norm_type 957 | total_norm = total_norm ** (1.0 / norm_type) 958 | return total_norm 959 | 960 | 961 | def norm(parameters): 962 | # Compute the global L2 norm of the weights of a model 963 | if isinstance(parameters, torch.Tensor): 964 | parameters = [parameters] 965 | parameters = list(filter(lambda p: p.grad is not None, parameters))  # NOTE(review): filters on p.grad, so weights that have not received a gradient are left out of the weight norm 966 | norm_type = float(2) 967 | total_norm = 0 968 | for p in parameters: 969 | param_norm = p.data.norm(norm_type) 970 | total_norm += param_norm.item() ** norm_type 971 | total_norm = total_norm ** (1.0 / norm_type) 972 | return total_norm 973 | -------------------------------------------------------------------------------- /gogepo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import gym 4 | import core_gogepo as core 5 | import
torch.nn as nn 6 | import torch.optim as optim 7 | from torch.distributions.normal import Normal 8 | import matplotlib.pyplot as plt 9 | import time 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--env_name", 15 | default="Swimmer-v3", 16 | choices=[ 17 | "Swimmer-v3", 18 | "Hopper-v3", 19 | "InvertedPendulum-v2", 20 | "MountainCarContinuous-v0", 21 | ], 22 | type=str, 23 | required=False, 24 | ) 25 | parser.add_argument("--verbose", default=0, type=int, required=False) 26 | parser.add_argument("--show_plots", default=0, type=int, required=False) 27 | parser.add_argument("--use_gpu", default=1, type=int, required=False) 28 | parser.add_argument("--seed", default=1234, type=int, required=False) 29 | args = parser.parse_args() 30 | 31 | verbose = args.verbose 32 | show_plots = args.show_plots 33 | 34 | # Default hyperparameters 35 | config = dict( 36 | algo="hpg_r", 37 | size_buffer=10000, 38 | size_buffer_command=64, 39 | max_episodes=1000000000, 40 | max_timesteps=3000000, 41 | run=1, # id 42 | ts_evaluation=10000, 43 | episodes_per_epoch=1, 44 | start_steps=0, 45 | seed=args.seed, 46 | # Env 47 | env_name=args.env_name, 48 | survival_bonus=False, 49 | # Policy 50 | neurons_policy=(256, 256), 51 | noise_policy=0.1, # std of distribution generating the noise for the perturbed policy 52 | observation_normalization=True, 53 | deterministic_actor=True, 54 | # Command evaluation 55 | ts_evaluation_generator=100000, 56 | rew_min=0, 57 | rew_max=3000, 58 | n_steps=20, 59 | # Command optimization 60 | normalize_command=True, 61 | noise_command=0, # max, heu 62 | drive_parameter=20, # max, heu 63 | update_command="sampled", # max, sampled, combine 64 | # vf 65 | neurons_vf=(256, 256), 66 | learning_rate_vf=5e-3, 67 | vf_iters=5, 68 | n_probing_states=200, 69 | # Generator 70 | use_hyper=True, 71 | gen_iters=20, 72 | neurons_generator=(256, 256), 73 | batch_size=16, 74 | learning_rate_gen=2e-6, 75 | z_dim_w=8, 76 |
z_dim_b=8, 77 | out_size_w=[16, 16], 78 | out_size_b=16, 79 | reset_command=100, 80 | weighted_sampling_command=False, 81 | use_bound=True, 82 | use_virtual_class=True, 83 | update_every_ts=False, 84 | update_every=100, 85 | weighted_sampling=True, 86 | scale=1.1, 87 | # IS 88 | use_is=False, 89 | learning_rate_command=1e-3, # is 90 | batch_size_command=16, # is 91 | updates_command=5, # is 92 | delta=0.5, 93 | use_gradient=False, 94 | use_bh=True, 95 | use_parallel=True, 96 | scale_layer_out=True, 97 | scale_parameter=2, 98 | use_max_pred=False, 99 | noise_command_up=0, 100 | drift_command_up=0, 101 | save=False, 102 | save_model_every=100000000, 103 | ) 104 | 105 | 106 | if config["env_name"] == "CartPole-v1":  # NOTE(review): config is a plain dict, so each dict.update below also stores allow_val_change=True as an extra key (looks like a leftover from a wandb-style config API); branches for envs outside the --env_name choices are unreachable 107 | config.update({"rew_min": 0}, allow_val_change=True) 108 | config.update({"rew_max": 500}, allow_val_change=True) 109 | elif config["env_name"] == "Swimmer-v3": 110 | config.update({"rew_min": -100}, allow_val_change=True) 111 | config.update({"rew_max": 365}, allow_val_change=True) 112 | elif config["env_name"] == "InvertedPendulum-v2": 113 | config.update({"rew_min": 0}, allow_val_change=True) 114 | config.update({"rew_max": 1000}, allow_val_change=True) 115 | config.update({"ts_evaluation_generator": 10000}, allow_val_change=True) 116 | config.update({"max_timesteps": 100000}, allow_val_change=True) 117 | config.update({"ts_evaluation": 1000}, allow_val_change=True) 118 | 119 | 120 | elif config["env_name"] == "Walker2d-v3": 121 | config.update({"rew_min": -100}, allow_val_change=True) 122 | config.update({"rew_max": 3000}, allow_val_change=True) 123 | elif config["env_name"] == "HalfCheetah-v3": 124 | config.update({"rew_min": -100}, allow_val_change=True) 125 | config.update({"rew_max": 4000}, allow_val_change=True) 126 | elif config["env_name"] == "Hopper-v3": 127 | config.update({"rew_min": -100}, allow_val_change=True) 128 | config.update({"rew_max": 3000}, allow_val_change=True) 129 | elif config["env_name"] == "InvertedDoublePendulum-v2":
130 | config.update({"rew_min": 0}, allow_val_change=True) 131 | config.update({"rew_max": 10000}, allow_val_change=True) 132 | config.update({"ts_evaluation_generator": 10000}, allow_val_change=True) 133 | config.update({"max_timesteps": 100000}, allow_val_change=True) 134 | config.update({"ts_evaluation": 1000}, allow_val_change=True) 135 | 136 | elif config["env_name"] == "MountainCarContinuous-v0": 137 | config.update({"rew_min": -100}, allow_val_change=True) 138 | config.update({"rew_max": 100}, allow_val_change=True) 139 | config.update({"ts_evaluation_generator": 10000}, allow_val_change=True) 140 | config.update({"max_timesteps": 100000}, allow_val_change=True) 141 | config.update({"ts_evaluation": 1000}, allow_val_change=True) 142 | 143 | 144 | if config["env_name"] in [ 145 | "MountainCarContinuous-v0", 146 | "InvertedPendulum-v2", 147 | "Reacher-v2", 148 | ]: 149 | config.update( 150 | { 151 | "ts_evaluation": 1000, 152 | "max_timesteps": 100000, 153 | }, 154 | allow_val_change=True, 155 | ) 156 | config.update( 157 | { 158 | "size_buffer_command": config["batch_size_command"], 159 | }, 160 | allow_val_change=True, 161 | ) 162 | 163 | # Use GPU or CPU 164 | if args.use_gpu: 165 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 166 | else: 167 | device = torch.device("cpu") 168 | 169 | torch.manual_seed(config["seed"]) 170 | np.random.seed(config["seed"])  # NOTE(review): Python's random module (used by core.Buffer sampling) and the environments are not seeded in this setup 171 | 172 | # Create env 173 | env = gym.make(config["env_name"]) 174 | env_test = gym.make(config["env_name"]) 175 | 176 | # Create replay buffer, policy, vf 177 | buffer = core.Buffer(config["size_buffer"], scale=config["scale"]) 178 | command_buffer = core.Buffer(config["size_buffer_command"]) 179 | statistics = core.Statistics(env.observation_space.shape) 180 | ac = core.MLPActorCritic( 181 | env.observation_space, 182 | env.action_space, 183 | config["n_probing_states"], 184 | hidden_sizes_actor=tuple(config["neurons_policy"]), 185 | activation=nn.Tanh, 186 |
hidden_sizes_critic=tuple(config["neurons_vf"]), 187 | device=device, 188 | critic=True, 189 | deterministic_actor=config["deterministic_actor"], 190 | ).to(device) 191 | command = core.Command() 192 | 193 | params_dim = len(nn.utils.parameters_to_vector(list(ac.pi.parameters()))) 194 | 195 | sizes = [param.shape for param in list(ac.parameters())] 196 | 197 | generator = core.Generator( 198 | ac.pi, 199 | params_dim, 200 | hidden_sizes=tuple(config["neurons_generator"]), 201 | use_hyper=config["use_hyper"], 202 | use_parallel=config["use_parallel"], 203 | hid_size_w=config["z_dim_w"], 204 | hid_size_b=config["z_dim_b"], 205 | out_size_w=config["out_size_w"], 206 | out_size_b=config["out_size_b"], 207 | device=device, 208 | scale_layer_out=config["scale_layer_out"], 209 | scale_parameter=config["scale_parameter"], 210 | ).to(device=device) 211 | 212 | if config["use_virtual_class"]: 213 | virtual_mlp = core.VirtualMLPPolicy( 214 | layer_sizes=[env.observation_space.shape[0]] 215 | + list(tuple(config["neurons_policy"])) 216 | + [env.action_space.shape[0]], 217 | act_lim=env.action_space.high[0], 218 | ) 219 | 220 | if verbose: 221 | print(generator.encoder) 222 | 223 | print(ac.pi) 224 | print( 225 | "Number of policy params:", 226 | len(nn.utils.parameters_to_vector(list(ac.pi.parameters()))), 227 | ) 228 | print( 229 | "Number of vf params:", 230 | len(nn.utils.parameters_to_vector(list(ac.v.parameters()))), 231 | ) 232 | print( 233 | "Number of generator params:", 234 | len(nn.utils.parameters_to_vector(list(generator.parameters()))), 235 | ) 236 | 237 | model_params = ( 238 | nn.utils.parameters_to_vector(list(ac.pi.parameters())).detach().to("cpu") 239 | ) 240 | q = torch.tensor([0.25, 0.5, 0.75]) 241 | print( 242 | "init quant 0.25, 0.5, 0.75", 243 | torch.quantile(model_params, q, dim=0, keepdim=True), 244 | ) 245 | 246 | for p in ac.pi.parameters(): 247 | print("max", torch.max(p)) 248 | 249 | # Setup optimizer 250 | optimize_generator = 
optim.Adam(generator.parameters(), lr=config["learning_rate_gen"]) 251 | optimize_vf = optim.Adam(ac.v.parameters(), lr=config["learning_rate_vf"]) 252 | 253 | 254 | def compute_vf_loss(progs, rew): 255 | q = ac.v( 256 | progs, 257 | use_virtual_module=config["use_virtual_class"], 258 | virtual_module=virtual_mlp, 259 | ) 260 | statistics.max_pred = max(statistics.max_pred, torch.max(q).detach().item()) 261 | loss = ((q - rew) ** 2).mean() 262 | return loss 263 | 264 | 265 | def compute_generator_loss(rew): 266 | loss = ( 267 | ( 268 | generator( 269 | rew / (config["rew_max"] - config["rew_min"]), 270 | use_virtual_module=config["use_virtual_class"], 271 | evaluator=ac.v, 272 | virtual_module=virtual_mlp, 273 | ) 274 | - rew.squeeze() 275 | ) 276 | ** 2 277 | ).mean() 278 | return loss 279 | 280 | 281 | def perturb_policy(policy): 282 | dist = Normal( 283 | torch.zeros(len(torch.nn.utils.parameters_to_vector(policy.parameters()))), 284 | scale=1, 285 | ) 286 | delta = dist.sample().to(device=device, non_blocking=True).detach() 287 | 288 | # Perturbe policy parameters 289 | params = torch.nn.utils.parameters_to_vector(policy.parameters()).detach() 290 | perturbed_params = params + config["noise_policy"] * delta 291 | 292 | # Copy perturbed parameters into a new policy 293 | perturbed_policy = core.MLPActorCritic( 294 | env.observation_space, 295 | env.action_space, 296 | config["n_probing_states"], 297 | hidden_sizes_actor=tuple(config["neurons_policy"]), 298 | activation=nn.Tanh, 299 | hidden_sizes_critic=tuple(config["neurons_vf"]), 300 | device=device, 301 | critic=False, 302 | deterministic_actor=config["deterministic_actor"], 303 | ).to(device) 304 | 305 | torch.nn.utils.vector_to_parameters(perturbed_params, perturbed_policy.parameters()) 306 | 307 | return perturbed_policy 308 | 309 | 310 | def evaluate_behavior_offline(rew_min, rew_max, n_steps): 311 | step = (rew_max - rew_min) / n_steps 312 | outputs = [] 313 | 314 | for idx in range(n_steps + 1): 315 | 
rew = rew_min + idx * step 316 | 317 | with torch.no_grad(): 318 | out_rew, parameters = generator( 319 | torch.tensor([rew / (config["rew_max"] - config["rew_min"])]) 320 | .float() 321 | .to(device) 322 | .unsqueeze(0), 323 | use_virtual_module=config["use_virtual_class"], 324 | evaluator=ac.v, 325 | virtual_module=virtual_mlp, 326 | return_all=True, 327 | ) 328 | 329 | outputs.append((out_rew.squeeze().cpu().numpy(), rew)) 330 | 331 | return outputs 332 | 333 | 334 | def perturb_command(command): 335 | if config["use_is"]: 336 | if config["use_gradient"]: 337 | perturbed_rew = command.cpu().item() * ( 338 | config["rew_max"] - config["rew_min"] 339 | ) 340 | 341 | else: 342 | perturbed_rew = command.cpu().item() 343 | 344 | else: 345 | dist = Normal(0, scale=1) 346 | delta = dist.sample().item() 347 | # Perturbe policy parameters 348 | perturbed_rew = ( 349 | command + config["drive_parameter"] + config["noise_command"] * delta 350 | ) 351 | 352 | return perturbed_rew 353 | 354 | 355 | def update_command(): 356 | if config["use_max_pred"]: 357 | steps = torch.rand(50).float().unsqueeze(1).to(device) 358 | steps = ( 359 | config["rew_min"] + steps * (statistics.max_rew - config["rew_min"]) * 1.2 360 | ) 361 | with torch.no_grad(): 362 | values = generator( 363 | steps / (config["rew_max"] - config["rew_min"]), 364 | use_virtual_module=config["use_virtual_class"], 365 | evaluator=ac.v, 366 | virtual_module=virtual_mlp, 367 | ) 368 | command.command = torch.nn.Parameter(torch.max(values)) 369 | else: 370 | command.command = torch.nn.Parameter(torch.as_tensor(statistics.max_rew)) 371 | return 372 | 373 | 374 | def update(): 375 | # Update evaluator 376 | start_time = time.perf_counter() 377 | 378 | for idx in range(1, config["vf_iters"]): 379 | # Sample batch 380 | hist = buffer.sample_replay( 381 | config["batch_size"], weighted_sampling=config["weighted_sampling"] 382 | ) 383 | prog, rew, rew_gen = zip(*hist) 384 | rew = ( 385 | torch.from_numpy(np.asarray(rew)) 
386 | .float() 387 | .to(device=device, non_blocking=True) 388 | .detach() 389 | ) 390 | prog = torch.stack(prog) 391 | optimize_vf.zero_grad() 392 | loss_vf = compute_vf_loss(prog, rew) 393 | loss_vf.backward() 394 | optimize_vf.step() 395 | 396 | statistics.up_v_time += time.perf_counter() - start_time 397 | start_time = time.perf_counter() 398 | 399 | for p in ac.v.parameters(): 400 | p.requires_grad = False 401 | for _ in range(1, config["gen_iters"]): 402 | if config["update_command"] == "generated": 403 | t = config["drive_parameter"] + config["noise_command"] 404 | rew_gen = ( 405 | statistics.min_rew 406 | + t 407 | + torch.rand(config["batch_size"]).float().unsqueeze(1).to(device) 408 | * (statistics.max_rew - statistics.min_rew) 409 | ) 410 | 411 | optimize_generator.zero_grad() 412 | loss_gen = compute_generator_loss(rew_gen.float().to(device)) 413 | loss_gen.backward() 414 | optimize_generator.step() 415 | elif config["update_command"] == "sampled": 416 | hist = buffer.sample_replay( 417 | config["batch_size"], weighted_sampling=config["weighted_sampling"] 418 | ) 419 | _, _, rew_gen = zip(*hist) 420 | 421 | rew_gen = torch.stack(rew_gen) 422 | rew_gen += ( 423 | config["drift_command_up"] 424 | + torch.rand(rew_gen.shape[0]).float().unsqueeze(1) 425 | * config["noise_command_up"] 426 | ) 427 | 428 | optimize_generator.zero_grad() 429 | loss_gen = compute_generator_loss(rew_gen.float().to(device)) 430 | loss_gen.backward() 431 | optimize_generator.step() 432 | else: 433 | raise ValueError 434 | 435 | for p in ac.v.parameters(): 436 | p.requires_grad = True 437 | 438 | statistics.up_policy_time += time.perf_counter() - start_time 439 | 440 | log_dict = { 441 | "loss_gen": loss_gen.item(), 442 | "loss_vf": loss_vf.item(), 443 | "grads_norm_generator": core.grad_norm(generator.parameters()), 444 | "norm_generator": core.norm(generator.parameters()), 445 | "norm_pvf": core.norm(ac.v.parameters()), 446 | "grads_norm_pvf": core.grad_norm(ac.v.parameters()), 
447 | "norm_prob_states": core.norm(ac.v.probing_states.parameters()), 448 | "grads_norm_prob_states": core.grad_norm(ac.v.probing_states.parameters()), 449 | "max_rew": statistics.max_rew, 450 | } 451 | if verbose: 452 | print(log_dict) 453 | 454 | return 455 | 456 | 457 | def evaluate(policy_params, log=True, n_eval=10): 458 | rew_evals = [] 459 | with torch.no_grad(): 460 | for _ in range(n_eval): 461 | 462 | # Simulate a trajectory and compute the total reward 463 | done = False 464 | obs = env_test.reset() 465 | rew_eval = 0 466 | while not done: 467 | obs = torch.as_tensor(obs, dtype=torch.float32) 468 | if config["observation_normalization"] and statistics.episode > 0: 469 | obs = statistics.normalize(obs) 470 | 471 | with torch.no_grad(): 472 | action = ac.act( 473 | obs.to(device), policy_params, virtual_module=virtual_mlp 474 | ) 475 | # action = ac.act(obs.to(device), policy_params) 476 | obs_new, r, done, _ = env_test.step(action[0]) 477 | 478 | rew_eval += r 479 | obs = obs_new 480 | 481 | rew_evals.append(rew_eval) 482 | if log: 483 | statistics.rew_eval = np.mean(rew_evals) 484 | statistics.push_rew(np.mean(rew_evals)) 485 | # Log results 486 | if log: 487 | print( 488 | "Ts", 489 | statistics.total_ts, 490 | "Ep", 491 | statistics.episode, 492 | "rew_eval", 493 | statistics.rew_eval, 494 | ) 495 | print( 496 | "time_sim", 497 | statistics.sim_time, 498 | "time_gen", 499 | statistics.gen_time, 500 | "time_up_pi", 501 | statistics.up_policy_time, 502 | "time_up_v", 503 | statistics.up_v_time, 504 | "total_time", 505 | statistics.total_time, 506 | ) 507 | return np.mean(rew_evals) 508 | 509 | 510 | def simulate_policy(perturbed_params): 511 | # Simulate a trajectory and compute the total reward 512 | done = False 513 | obs = env.reset() 514 | rew = 0 515 | while not done: 516 | obs = torch.as_tensor(obs, dtype=torch.float32) 517 | if config["observation_normalization"]: 518 | statistics.push_obs(obs) 519 | if statistics.episode > 0: 520 | obs = 
statistics.normalize(obs) 521 | 522 | with torch.no_grad(): 523 | action = ac.act( 524 | obs.to(device), perturbed_params, virtual_module=virtual_mlp 525 | ) 526 | obs_new, r, done, _ = env.step(action[0]) 527 | if not config["survival_bonus"]: 528 | if ( 529 | config["env_name"] == "Hopper-v3" 530 | or config["env_name"] == "Ant-v3" 531 | or config["env_name"] == "Walker2d-v3" 532 | ): 533 | rew += r - 1 534 | elif config["env_name"] == "Humanoid-v3": 535 | rew += r - 5 536 | else: 537 | rew += r 538 | else: 539 | rew += r 540 | 541 | statistics.total_ts += 1 542 | 543 | # Evaluate current policy 544 | if ( 545 | statistics.total_ts % config["ts_evaluation"] == 0 546 | and statistics.episode > 0 547 | ): 548 | with torch.no_grad(): 549 | if config["use_max_pred"]: 550 | parameters = generator( 551 | torch.tensor( 552 | [ 553 | command.command.cpu().item() 554 | / (config["rew_max"] - config["rew_min"]) 555 | ] 556 | ) 557 | .float() 558 | .to(device) 559 | .unsqueeze(0) 560 | ) 561 | 562 | else: 563 | parameters = generator( 564 | torch.tensor( 565 | [ 566 | statistics.max_rew 567 | / (config["rew_max"] - config["rew_min"]) 568 | ] 569 | ) 570 | .float() 571 | .to(device) 572 | .unsqueeze(0) 573 | ) 574 | parameters = parameters.squeeze() 575 | evaluate(parameters) 576 | 577 | # Update 578 | if ( 579 | statistics.total_ts > config["start_steps"] 580 | and config["update_every_ts"] 581 | and statistics.episode > 0 582 | ): 583 | if statistics.total_ts % config["update_every"] == 0: 584 | update() 585 | 586 | # save metrics 587 | model_states = {"ac": ac, "generator": generator, "statistics": statistics} 588 | 589 | if config["save"]: 590 | 591 | if statistics.total_ts % config["save_model_every"] == 0: 592 | torch.save( 593 | model_states, 594 | "data/model_" 595 | + str(config["seed"]) 596 | + str(config["env_name"]) 597 | + str(statistics.total_ts) 598 | + ".pth", 599 | ) 600 | 601 | print("saving model") 602 | 603 | if statistics.total_ts == 1000000: 604 | 
log_dict = { 605 | "rew_eval_1M": statistics.rew_eval, 606 | "average_reward_1M": np.mean(statistics.rewards), 607 | "average_last_rewards_1M": np.mean(statistics.last_rewards), 608 | } 609 | if verbose: 610 | print(log_dict) 611 | 612 | if statistics.total_ts % config["ts_evaluation_generator"] == 0: 613 | result = evaluate_behavior( 614 | config["rew_min"], config["rew_max"], config["n_steps"] 615 | ) 616 | if show_plots: 617 | y, x = zip(*result) 618 | fig, ax = plt.subplots(1, 1) 619 | ax.plot(x, y) 620 | ax.plot(x, x) 621 | plt.show() 622 | 623 | result = evaluate_behavior_offline( 624 | config["rew_min"], config["rew_max"], config["n_steps"] 625 | ) 626 | if show_plots: 627 | y, x = zip(*result) 628 | fig2, ax2 = plt.subplots(1, 1) 629 | ax2.plot(x, y) 630 | ax2.plot(x, x) 631 | plt.show() 632 | obs = obs_new 633 | return rew 634 | 635 | 636 | def evaluate_behavior(rew_min, rew_max, n_steps): 637 | step = (rew_max - rew_min) / n_steps 638 | outputs = [] 639 | 640 | for idx in range(n_steps + 1): 641 | rew = rew_min + idx * step 642 | 643 | with torch.no_grad(): 644 | parameters = generator( 645 | torch.tensor([rew / (config["rew_max"] - config["rew_min"])]) 646 | .float() 647 | .to(device) 648 | .unsqueeze(0) 649 | ) 650 | parameters = parameters.squeeze() 651 | out_rew = evaluate(parameters, log=False, n_eval=1) 652 | outputs.append((out_rew, rew)) 653 | 654 | return outputs 655 | 656 | 657 | def train(): 658 | start_time = time.perf_counter() 659 | 660 | # Choose command 661 | if statistics.episode > 0: 662 | perturbed_command = perturb_command(command.command) 663 | else: 664 | perturbed_command = 1 665 | 666 | # Generate policy and perturbe it 667 | with torch.no_grad(): 668 | perturbed_params, logp_a = generator( 669 | torch.tensor([perturbed_command / (config["rew_max"] - config["rew_min"])]) 670 | .unsqueeze(0) 671 | .float() 672 | .to(device), 673 | noise=config["noise_policy"], 674 | ) 675 | perturbed_params = perturbed_params.squeeze() 676 | 677 | 
mean_param = generator( 678 | torch.tensor([perturbed_command / (config["rew_max"] - config["rew_min"])]) 679 | .unsqueeze(0) 680 | .float() 681 | .to(device) 682 | ) 683 | mean_param = mean_param.squeeze().to("cpu") 684 | 685 | if statistics.episode == 0: 686 | torch.nn.utils.vector_to_parameters(mean_param, ac.pi.parameters()) 687 | 688 | if verbose: 689 | for p in ac.pi.parameters(): 690 | print("max gen", torch.max(p)) 691 | 692 | statistics.gen_time += time.perf_counter() - start_time 693 | start_time = time.perf_counter() 694 | 695 | # Simulate a trajectory and compute the total reward 696 | rew = simulate_policy(perturbed_params) 697 | statistics.max_rew = max(statistics.max_rew, rew) 698 | statistics.min_rew = min(statistics.min_rew, rew) 699 | 700 | # Store data in replay buffer 701 | buffer.history.append((perturbed_params, rew, torch.tensor([rew]).float())) 702 | command_buffer.history.append( 703 | (perturbed_params, torch.tensor([rew]).float(), logp_a, mean_param) 704 | ) 705 | if len(command_buffer.history) > command_buffer.size_buffer: 706 | command_buffer.history.pop(0) 707 | 708 | statistics.episode += 1 709 | 710 | statistics.sim_time += time.perf_counter() - start_time 711 | 712 | # Update 713 | if statistics.total_ts > config["start_steps"] and not config["update_every_ts"]: 714 | update() 715 | 716 | # Log results 717 | if statistics.episode % 50 == 0: 718 | print("Ts", statistics.total_ts, "Rew", rew) 719 | 720 | log_dict = { 721 | "rew": rew, 722 | "steps": statistics.total_ts, 723 | "episode": statistics.episode, 724 | "command": command.command.detach().item(), 725 | "perturbed_command": perturbed_command, 726 | "max_pred": statistics.max_pred, 727 | } 728 | if verbose: 729 | print(log_dict) 730 | 731 | statistics.push_rew_env(rew) 732 | 733 | # Update command 734 | if ( 735 | statistics.episode % config["episodes_per_epoch"] == 0 736 | and statistics.total_ts > config["start_steps"] 737 | ): 738 | update_command() 739 | return 740 | 741 | 
742 | # Loop over episodes 743 | while ( 744 | statistics.total_ts < config["max_timesteps"] 745 | and statistics.episode < config["max_episodes"] 746 | ): 747 | start_time = time.perf_counter() 748 | train() 749 | statistics.total_time += time.perf_counter() - start_time 750 | 751 | if config["save"]: 752 | # save metrics 753 | model_states = { 754 | "ac": ac, 755 | "generator": generator, 756 | "statistics": statistics, 757 | "buffer": buffer, 758 | } 759 | 760 | torch.save( 761 | model_states, 762 | "data/final_model_" 763 | + str(config["seed"]) 764 | + str(config["env_name"]) 765 | + str(statistics.total_ts) 766 | + ".pth", 767 | ) 768 | 769 | print("saving final model") 770 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | torch 4 | gym 5 | mujoco_py --------------------------------------------------------------------------------