├── README.md ├── embedding_networks.py ├── env_utils.py ├── eval_pdvf.py ├── figures ├── ablations_pdvf.png ├── pdvf_gif.png └── results_pdvf.png ├── myant ├── myant │ ├── __init__.py │ ├── envs │ │ ├── __init__.py │ │ └── myant.py │ └── utils.py └── setup.py ├── myspaceship ├── myspaceship │ ├── __init__.py │ └── envs │ │ ├── __init__.py │ │ └── myspaceship.py └── setup.py ├── myswimmer ├── myswimmer │ ├── __init__.py │ ├── envs │ │ ├── __init__.py │ │ └── myswimmer.py │ └── utils.py └── setup.py ├── pdvf_arguments.py ├── pdvf_networks.py ├── pdvf_storage.py ├── pdvf_utils.py ├── ppo ├── __init__.py ├── algo │ ├── __init__.py │ └── ppo.py ├── arguments.py ├── distributions.py ├── envs.py ├── evaluation.py ├── model.py ├── ppo_main.py ├── storage.py └── utils.py ├── requirements.txt ├── train_dynamics_embedding.py ├── train_pdvf.py ├── train_policy_embedding.py └── train_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Policy-Dynamics Value Functions (PD-VF) 2 | 3 | This is source code for the paper 4 | 5 | [Fast Adaptation to New Environments via Policy-Dynamics Value Functions](https://arxiv.org/pdf/2007.02879) 6 | 7 | by Roberta Raileanu, Max Goldstein, Arthur Szlam, and Rob Fergus, 8 | 9 | accepted at ICML 2020. 10 | 11 | 12 | 13 | ## Citation 14 | If you use this code in your own work, please cite our paper: 15 | ``` 16 | @incollection{icml2020_3993, 17 | abstract = {Standard RL algorithms assume fixed environment dynamics and require a significant amount of interaction to adapt to new environments. We introduce Policy-Dynamics Value Functions (PD-VF), a novel approach for rapidly adapting to dynamics different from those previously seen in training. PD-VF explicitly estimates the cumulative reward in a space of policies and environments. 
An ensemble of conventional RL policies is used to gather experience on training environments, from which embeddings of both policies and environments can be learned. Then, a value function conditioned on both embeddings is trained. At test time, a few actions are sufficient to infer the environment embedding, enabling a policy to be selected by maximizing the learned value function (which requires no additional environment interaction). We show that our method can rapidly adapt to new dynamics on a set of MuJoCo domains. }, 18 | author = {Raileanu, Roberta and Goldstein, Max and Szlam, Arthur and Rob Fergus, Facebook}, 19 | booktitle = {Proceedings of Machine Learning and Systems 2020}, 20 | pages = {7078--7089}, 21 | title = {Fast Adaptation to New Environments via Policy-Dynamics Value Functions}, 22 | year = {2020} 23 | } 24 | ``` 25 | 26 | ## Requirements 27 | ``` 28 | conda create -n pdvf python=3.7 29 | conda activate pdvf 30 | 31 | git clone git@github.com:rraileanu/policy-dynamics-value-functions.git 32 | cd policy-dynamics-value-functions 33 | pip install -r requirements.txt 34 | 35 | cd myant 36 | pip install -e . 37 | 38 | cd ../myswimmer 39 | pip install -e . 40 | 41 | cd ../myspaceship 42 | pip install -e . 43 | ``` 44 | 45 | ## (1) Reinforcement Learning Phase 46 | 47 | Train PPO policies on each environment, one seed for each. 48 | 49 | Each of the commands below needs to be run 50 | for seed in [0,...,4] and for default-ind in [0,...,19]. 
51 | 52 | ### Spaceship 53 | ``` 54 | python ppo/ppo_main.py \ 55 | --env-name spaceship-v0 --default-ind 0 --seed 0 56 | ``` 57 | 58 | ### Swimmer 59 | ``` 60 | python ppo/ppo_main.py \ 61 | --env-name myswimmer-v0 --default-ind 0 --seed 0 62 | ``` 63 | 64 | ### Ant-wind 65 | ``` 66 | python ppo/ppo_main.py \ 67 | --env-name myant-v0 --default-ind 0 --seed 0 68 | ``` 69 | 70 | ## (2) Self-Supervised Learning Phase 71 | 72 | ## Dynamics Embedding 73 | 74 | ### Spaceship 75 | ``` 76 | python train_dynamics_embedding.py \ 77 | --env-name spaceship-v0 \ 78 | --dynamics-embedding-dim 8 --dynamics-batch-size 8 \ 79 | --inf-num-steps 1 --num-dec-traj 10 \ 80 | --save-dir-dynamics-embedding ./models/dynamics-embeddings 81 | ``` 82 | 83 | ### Swimmer 84 | ``` 85 | python train_dynamics_embedding.py \ 86 | --env-name myswimmer-v0 \ 87 | --dynamics-embedding-dim 2 --dynamics-batch-size 32 \ 88 | --inf-num-steps 1 --num-dec-traj 10 \ 89 | --save-dir-dynamics-embedding ./models/dynamics-embeddings 90 | ``` 91 | 92 | ### Ant-wind 93 | ``` 94 | python train_dynamics_embedding.py \ 95 | --env-name myant-v0 \ 96 | --dynamics-embedding-dim 8 --dynamics-batch-size 32 \ 97 | --inf-num-steps 2 --num-dec-traj 10 \ 98 | --save-dir-dynamics-embedding ./models/dynamics-embeddings 99 | ``` 100 | 101 | ## Policy Embedding 102 | 103 | ### Spaceship 104 | ``` 105 | python train_policy_embedding.py \ 106 | --env-name spaceship-v0 --num-dec-traj 1 \ 107 | --save-dir-policy-embedding ./models/policy-embeddings 108 | ``` 109 | 110 | ### Swimmer 111 | ``` 112 | python train_policy_embedding.py \ 113 | --env-name myswimmer-v0 --num-dec-traj 1 \ 114 | --save-dir-policy-embedding ./models/policy-embeddings 115 | ``` 116 | 117 | ### Ant-wind 118 | ``` 119 | python train_policy_embedding.py \ 120 | --env-name myant-v0 --num-dec-traj 1 \ 121 | --save-dir-policy-embedding ./models/policy-embeddings 122 | ``` 123 | 124 | ## (3) Supervised Learning Phase 125 | 126 | ### Spaceship 127 | ``` 128 | python 
train_pdvf.py \ 129 | --env-name spaceship-v0 \ 130 | --dynamics-batch-size 8 --policy-batch-size 2048 \ 131 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \ 132 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \ 133 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \ 134 | --save-dir-policy-embedding ./models/policy-embeddings \ 135 | --save-dir-pdvf ./models/pdvf-models 136 | ``` 137 | 138 | ### Swimmer 139 | ``` 140 | python train_pdvf.py \ 141 | --env-name myswimmer-v0 \ 142 | --dynamics-batch-size 8 --policy-batch-size 2048 \ 143 | --dynamics-embedding-dim 2 --policy-embedding-dim 8 \ 144 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \ 145 | --norm-reward --min-reward -60 --max-reward 200 \ 146 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \ 147 | --save-dir-policy-embedding ./models/policy-embeddings \ 148 | --save-dir-pdvf ./models/pdvf-models 149 | ``` 150 | 151 | ### Ant-wind 152 | ``` 153 | python train_pdvf.py \ 154 | --env-name myant-v0 \ 155 | --dynamics-batch-size 32 --policy-batch-size 2048 \ 156 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \ 157 | --num-dec-traj 10 --inf-num-steps 2 --log-interval 10 \ 158 | --norm-reward --min-reward -400 --max-reward 1000 \ 159 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \ 160 | --save-dir-policy-embedding ./models/policy-embeddings \ 161 | --save-dir-pdvf ./models/pdvf-models 162 | ``` 163 | 164 | ## (4) Evaluation Phase 165 | 166 | ### Spaceship 167 | ``` 168 | python eval_pdvf.py \ 169 | --env-name spaceship-v0 --stage 20 \ 170 | --dynamics-batch-size 8 --policy-batch-size 2048 \ 171 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \ 172 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \ 173 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \ 174 | --save-dir-policy-embedding ./models/policy-embeddings \ 175 | --save-dir-pdvf ./models/pdvf-models 176 | ``` 177 | 178 | ### Swimmer 179 | ``` 180 | python 
eval_pdvf.py \ 181 | --env-name myswimmer-v0 --stage 20 \ 182 | --dynamics-batch-size 8 --policy-batch-size 2048 \ 183 | --dynamics-embedding-dim 2 --policy-embedding-dim 8 \ 184 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \ 185 | --norm-reward --min-reward -60 --max-reward 200 \ 186 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \ 187 | --save-dir-policy-embedding ./models/policy-embeddings \ 188 | --save-dir-pdvf ./models/pdvf-models 189 | ``` 190 | 191 | ### Ant-wind 192 | ``` 193 | python eval_pdvf.py \ 194 | --env-name myant-v0 --stage 20 \ 195 | --dynamics-batch-size 32 --policy-batch-size 2048 \ 196 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \ 197 | --num-dec-traj 10 --inf-num-steps 2 --log-interval 10 \ 198 | --norm-reward --min-reward -400 --max-reward 1000 \ 199 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \ 200 | --save-dir-policy-embedding ./models/policy-embeddings \ 201 | --save-dir-pdvf ./models/pdvf-models 202 | ``` 203 | 204 | ## Results 205 | ![Performance on Test Environment](/figures/results_pdvf.png) 206 | 207 | ![Ablations](/figures/ablations_pdvf.png) 208 | 209 | 210 | -------------------------------------------------------------------------------- /embedding_networks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models for the policy and dynamics embeddings / autoencoders. 3 | 4 | The encoder is a transformer and the code is built on top 5 | of the following open sourced implementation: 6 | https://github.com/jadore801120/attention-is-all-you-need-pytorch/ 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math, copy 12 | from ppo.utils import init 13 | 14 | 15 | def clones(module, N): 16 | "Produce N identical layers." 
def attention(query, key, value, mask=None, dropout=None):
    """
    Scaled dot-product attention.

    The similarity logits are ``query @ key^T`` over the last two
    dimensions, divided by the square root of the size of ``query``'s last
    dimension. Positions where ``mask == 0`` are set to -1e9 before the
    softmax, so they receive (numerically) zero weight.

    Args:
        query, key, value: tensors whose last two dimensions are matrix
            multiplied as described above.
        mask: optional tensor broadcastable to the logits; zeros mark
            positions to be ignored.
        dropout: optional callable applied to the attention weights.

    Returns:
        A pair ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale

    if mask is not None:
        # A large negative logit effectively removes the position.
        logits = logits.masked_fill(mask == 0, -1e9)

    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    context = torch.matmul(weights, value)
    return context, weights
class LayerNorm(nn.Module):
    """
    Layer normalisation over the last dimension with a learnable gain
    (``a_2``, initialised to ones) and bias (``b_2``, initialised to zeros).

    The input is centred by its mean and divided by ``std + eps``, where
    both statistics are taken over the last dimension.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Parameter names are kept as-is: they are the state_dict keys.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation (not the variance).
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2
class EncoderLayer(nn.Module):
    """
    Single self-attention encoder layer.

    When ``no_norm`` is true the input goes straight through
    ``self_attn``; otherwise the attention call is routed through
    ``self.sublayer`` (a ``SublayerConnection`` built from ``size`` and
    ``dropout``). ``size`` is also exposed as an attribute.
    """

    def __init__(self, size, self_attn, dropout, no_norm=False):
        super().__init__()
        # Sub-modules are registered in the same order as before.
        self.self_attn = self_attn
        self.sublayer = SublayerConnection(size, dropout)
        self.size = size
        self.no_norm = no_norm

    def forward(self, x, mask):
        """Apply self-attention to ``x`` (query = key = value) under ``mask``."""

        def _self_attend(inp):
            return self.self_attn(inp, inp, inp, mask)

        if self.no_norm:
            return _self_attend(x)
        return self.sublayer(x, _self_attend)
class DecoderSpaceship(nn.Module):
    """
    MLP decoder that samples a 2-D action for the Spaceship environment.

    Two ReLU hidden layers feed two linear heads, one per action
    coordinate. For each head, column 0 is used as an offset and column 1
    multiplies standard-normal noise, so ``output_size`` must be at least 2.

    Args:
        input_size: width of the input fed to the first linear layer
            (after ``time`` has been concatenated, when it is given).
        hidden_size: width of the two hidden layers.
        output_size: width of each output head.
        device: kept for backward compatibility and stored on the module;
            the sampling noise now follows the device of the head outputs.
    """

    def __init__(self, input_size, hidden_size, output_size, device='cuda'):
        super(DecoderSpaceship, self).__init__()

        self.device = device
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.fc_in = nn.Linear(input_size, hidden_size)
        self.fc_mid = nn.Linear(hidden_size, hidden_size)

        self.fc_out_x = nn.Linear(hidden_size, output_size)
        self.fc_out_y = nn.Linear(hidden_size, output_size)

    def forward(self, input, time=None):
        """
        Sample one action per row of ``input``.

        Args:
            input: 2-D tensor of features, one row per sample.
            time: optional 2-D tensor appended to ``input`` along dim 1.
                When omitted, ``input`` is used unchanged.

        Returns:
            Tensor of shape ``(rows, 2)`` holding the sampled x and y
            action components.
        """
        if time is not None:
            input = torch.cat((input, time), dim=1)
        hidden = F.relu(self.fc_in(input))
        hidden = F.relu(self.fc_mid(hidden))

        params_x = self.fc_out_x(hidden)
        params_y = self.fc_out_y(hidden)

        rows = params_x.shape[0]
        # Draw the noise exactly as before (same generator, x first then y)
        # but move it to wherever the parameters live instead of to the
        # constructor's `device`, which may not match the module's device.
        noise_x = torch.randn(rows).to(params_x.device)
        noise_y = torch.randn(rows).to(params_y.device)
        act_x = noise_x * params_x[:, 1] + params_x[:, 0]
        act_y = noise_y * params_y[:, 1] + params_y[:, 0]

        return torch.cat((act_x.unsqueeze(1), act_y.unsqueeze(1)), dim=1)
def make_one_hot(x, nb_digits=3, batch_size=1):
    '''
    Convert integer class indices to a one-hot float tensor.

    Args:
        x: a Python int, sequence, or tensor holding ``batch_size``
            indices in ``[0, nb_digits)``.
        nb_digits: number of classes (width of the one-hot rows).
        batch_size: number of indices contained in ``x``.

    Returns:
        A float32 CPU tensor of shape ``(batch_size, nb_digits)`` with a
        single 1 per row.
    '''
    # as_tensor accepts plain ints/lists as well as tensors, matching the
    # "convert int" contract without copying existing tensors.
    indices = torch.as_tensor(x).reshape(batch_size, 1).long()

    y_onehot = torch.zeros(batch_size, nb_digits, dtype=torch.float32)
    # scatter_ needs the index on the same device as the destination.
    y_onehot.scatter_(1, indices.to(y_onehot.device), 1)

    return y_onehot
def sample_env_context(self, policy_fn, env_idx=None):
    '''
    Roll out up to ``self.inf_num_steps`` transitions in the environment
    and flatten them into one zero-padded context vector.

    Each recorded transition is the concatenation (state, action,
    next_state). The transition on which the environment reports ``done``
    is still recorded; anything after it is left as zero padding.

    Args:
        policy_fn: policy exposing ``float()``, ``act(...)`` and
            ``recurrent_hidden_state_size``.
        env_idx: if given, the environment is reset with
            ``env_id=env_idx``; otherwise ``self.reset_same()`` is used.

    Returns:
        1-D tensor of length ``self.max_seq_length`` on ``args.device``.
    '''
    device = self.args.device
    if env_idx is not None:
        state = self.env.reset(env_id=env_idx)
    else:
        state = self.reset_same()
    state = torch.tensor(state).float()
    policy_fn = policy_fn.float()

    transitions = []
    for _ in range(self.inf_num_steps):
        recurrent_hidden_state = torch.zeros(
            1, policy_fn.recurrent_hidden_state_size, device=device)
        mask = torch.zeros(1, 1, device=device)
        action = policy_fn.act(
            state.squeeze().unsqueeze(0).to(torch.device(device)).float(),
            recurrent_hidden_state.float(), mask.float(),
            deterministic=True)[1].detach()

        action_tensor = action.float().reshape(self.action_dim)
        state_action_tensor = torch.cat(
            (state.squeeze().to(device=device), action_tensor), dim=0)

        next_state, _, done, _ = self.env.step(*action.cpu().numpy())
        state = torch.FloatTensor(next_state)
        transitions.append(torch.cat(
            (state_action_tensor, state.to(device=device)), dim=0).unsqueeze(1))

        if done:
            # Steps after termination were previously filled with an
            # all-zero row created on the CPU, which cannot be stacked with
            # rows living on a CUDA device. The padding below yields the
            # same zeros without the device mismatch.
            break

    if not transitions:
        # inf_num_steps == 0: nothing observed, return pure padding.
        return torch.zeros(self.max_seq_length, device=device)

    flat = torch.stack(transitions).reshape(len(transitions) * self.enc_input_size)
    return F.pad(flat, (0, self.max_seq_length - flat.numel()))
def generic_step(self, policy_fn, state):
    '''
    Advance the environment by one step using ``policy_fn``.

    ``policy_fn`` is called directly on the state (with a leading batch
    dimension added) and its output is passed to ``self.env.step`` via
    ``.numpy()``.

    Returns a dict with keys 'state' (the next state), 'action',
    'sa_tensor' (state and action concatenated), 'sas_tensor' (state,
    action and next state as a column; all zeros when the step ended the
    episode), 'reward' and 'done'.
    '''
    batched_state = state.squeeze().unsqueeze(0)
    action = policy_fn(batched_state)
    action_tensor = torch.FloatTensor(action)
    state_action_tensor = torch.cat(
        (state.squeeze(), action_tensor.squeeze()), dim=0)

    next_state, reward, done, _ = self.env.step(action.numpy())
    next_state_tensor = torch.Tensor(next_state).float()

    if done:
        # Terminal transitions are represented by an all-zero column.
        sas_tensor = torch.FloatTensor([0] * self.enc_input_size) \
            .reshape(self.enc_input_size, 1)
    else:
        sas_tensor = torch.cat(
            (state_action_tensor, next_state_tensor), dim=0).unsqueeze(1)

    return {
        'state': next_state_tensor,
        'action': action_tensor,
        'sa_tensor': state_action_tensor,
        'sas_tensor': sas_tensor,
        'reward': reward,
        'done': done,
    }
def get_decoded_traj_mujoco(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False):
    '''
    Decode a trajectory using the policy decoder conditioned on a given
    policy embedding, in a given environment for a given initial state.
    Works with MuJoCo environments (Ant and Swimmer).

    The environment is reset with ``env_id=env_idx``, its simulator state
    is overwritten with ``init_state``, and the decoder is then rolled out
    from ``init_obs`` for at most ``self.max_episode_steps`` steps or until
    the environment reports ``done``.

    Args:
        args: namespace; only ``args.device`` is read here.
        init_state: simulator state passed to ``self.env.sim.set_state``.
        init_obs: observation tensor matching ``init_state``
            (assumed 1-D; it is unsqueezed and concatenated with
            ``policy_emb`` along dim 1 -- TODO confirm).
        policy_emb: policy embedding; its first dimension sizes the
            recurrent state and mask fed to the decoder.
        decoder: object exposing ``act(...)`` and
            ``recurrent_hidden_state_size``.
        env_idx: environment id used for the reset.
        verbose: unused; kept for interface compatibility.

    Returns:
        Four lists with one entry per executed step: decoder inputs,
        recurrent states, masks, and actions (as tensors on the device).
    '''
    self.env.reset(env_id=env_idx)
    self.env.sim.set_state(init_state)
    state = init_obs

    device = args.device
    batch = policy_emb.shape[0]

    all_emb_state = []
    all_recurrent_state = []
    all_mask = []
    all_action = []
    # The previous version also mutated the loop counter and accumulated a
    # discounted (optionally normalised) return that was never returned or
    # stored; that dead bookkeeping has been dropped.
    for _ in range(self.max_episode_steps):
        recurrent_hidden_state = torch.zeros(
            batch, decoder.recurrent_hidden_state_size,
            device=device, requires_grad=True).float()
        mask_dec = torch.zeros(batch, 1, device=device, requires_grad=True).float()
        emb_state_input = torch.cat(
            (policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device)
        action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec,
                             deterministic=True)[1]

        # The environment takes the squeezed action; the stored copy keeps
        # the decoder's original shape.
        action_flat = action.squeeze().cpu().detach().numpy()
        action = action.cpu().detach().numpy()

        all_emb_state.append(emb_state_input)
        all_recurrent_state.append(recurrent_hidden_state)
        all_mask.append(mask_dec)
        all_action.append(torch.tensor(action).to(device))

        next_state, _, done, _ = self.env.step(action_flat)
        state = torch.FloatTensor(next_state)
        if done:
            break

    return all_emb_state, all_recurrent_state, all_mask, all_action
234 | ''' 235 | done = False 236 | state_batch = [] 237 | tgt_batch = [] 238 | src_batch = [] 239 | mask_batch = [] 240 | mask_batch_all = [] 241 | src = [] 242 | masks = [] 243 | 244 | if env_idx is not None: 245 | init_state = self.env.reset(env_id=env_idx) 246 | else: 247 | init_state = self.env.reset() 248 | 249 | trajectory = self.sample_policy(policy_idx=policy_idx) 250 | state_tensor = torch.tensor(init_state).unsqueeze(0) 251 | recurrent_hidden_state = torch.zeros( 252 | 1, trajectory.recurrent_hidden_state_size, device=self.args.device) 253 | mask = torch.zeros(1, 1, device=self.args.device) 254 | trajectory = trajectory.float().to(self.args.device) 255 | action = trajectory.act( 256 | state_tensor.to(torch.device(self.args.device)).float(), 257 | recurrent_hidden_state.float(), mask.float(), 258 | deterministic=True)[1] 259 | action_tensor = action.float().reshape(self.action_dim) 260 | state_action_tensor = torch.cat([ 261 | state_tensor.to(torch.device(self.args.device)).float().squeeze(), 262 | action_tensor], dim=0).unsqueeze(1) 263 | 264 | for t in range(self.args.max_num_steps): 265 | state_batch.append(state_tensor) 266 | src.append(state_action_tensor) 267 | masks.append(torch.FloatTensor([done == False])) 268 | mask_batch.append(torch.FloatTensor([done == False])) 269 | 270 | state_tensor, _, done, _ = self.env.step(action_tensor.cpu().detach().numpy()) 271 | 272 | tgt_batch.append(action_tensor.detach()) 273 | 274 | 275 | state_tensor = torch.tensor(state_tensor).unsqueeze(0) 276 | recurrent_hidden_state = torch.zeros( 277 | 1, trajectory.recurrent_hidden_state_size, device=self.args.device) 278 | mask = torch.zeros(1, 1, device=self.args.device) 279 | trajectory = trajectory.float().to(self.args.device) 280 | action = trajectory.act( 281 | state_tensor.to(torch.device(self.args.device)).float(), 282 | recurrent_hidden_state.float(), mask.float(), 283 | deterministic=True)[1].detach() 284 | action_tensor = action.float().reshape(self.action_dim) 
285 | state_action_tensor = torch.cat([ 286 | state_tensor.to(torch.device(self.args.device)).float().squeeze(), 287 | action_tensor], dim=0).unsqueeze(1) 288 | 289 | for t in range(self.args.max_num_steps): 290 | src_tensor = torch.stack(src).squeeze(2) 291 | src_batch.append(src_tensor) 292 | mask_batch_all.append(torch.stack(masks)) 293 | 294 | return state_batch, tgt_batch, src_batch, mask_batch, mask_batch_all 295 | 296 | def sample_k_traj_zeroshot(self, k_traj, policy_idx=0, env_idx=None): 297 | ''' 298 | Sample a number of trajectories using the inferred dynamics 299 | from a small number of interactions with a new environment. 300 | ''' 301 | trajectory = self.sample_policy(policy_idx=policy_idx) 302 | context_env = self.sample_env_context(trajectory, env_idx=env_idx) 303 | state_action_list = [] 304 | target_list = [] 305 | source_list = [] 306 | 307 | for k in range(k_traj): 308 | eval_episode_rewards = [] 309 | init_state = torch.tensor(self.env.reset(same=True)).float() 310 | state = init_state 311 | obs = init_state 312 | trajectory = self.sample_policy(policy_idx=policy_idx) 313 | for _ in range(self.inf_num_steps): 314 | obs_feat = obs.float().squeeze().reshape(1,-1).to(torch.device(self.args.device)) 315 | 316 | recurrent_hidden_state = torch.zeros( 317 | 1, trajectory.recurrent_hidden_state_size, device=self.args.device) 318 | mask = torch.zeros(1, 1, device=self.args.device) 319 | trajectory = trajectory.float() 320 | action = trajectory.act( 321 | obs_feat, recurrent_hidden_state.float(), mask.float(), 322 | deterministic=True)[1].detach() 323 | 324 | action_tensor = action.float().reshape(self.action_dim) 325 | state_action_tensor = torch.cat([obs.squeeze().to(device=self.args.device), 326 | action_tensor.to(device=self.args.device)], dim=0) 327 | obs, reward, done, infos = self.env.step(*action.cpu().numpy()) 328 | 329 | obs = torch.tensor(obs).float() 330 | target_tensor = obs 331 | 332 | state_action_list.append(state_action_tensor) 333 | 
target_list.append(target_tensor) 334 | source_list.append(context_env.reshape(self.args.max_num_steps, -1)) 335 | 336 | return state_action_list, target_list, source_list 337 | 338 | 339 | class EnvSamplerPDVF(): 340 | ''' 341 | Environment sampler object for training the 342 | Policy-Dynamics Value Function. 343 | ''' 344 | def __init__(self, env, base_policy, args): 345 | if args.env_name.startswith('spaceship'): 346 | self.env = env.venv.envs[0] 347 | elif args.env_name.startswith('myswimmer'): 348 | self.env = env.venv.envs[0].env.env 349 | else: 350 | self.env = env.venv.envs[0].env.env.env 351 | self.env.reset() 352 | self.env1 = env 353 | self.base_policy = base_policy 354 | self.args = args 355 | self.env_name = args.env_name 356 | self.action_dim = self.env1.action_space.shape[0] 357 | self.state_dim = self.env1.observation_space.shape[0] 358 | self.enc_input_size = 2 * self.state_dim + self.action_dim 359 | self.inf_num_steps = args.inf_num_steps 360 | self.env_max_seq_length = args.inf_num_steps * self.enc_input_size 361 | self.max_seq_length = args.max_num_steps * self.enc_input_size 362 | 363 | if 'ant' in args.env_name: 364 | self.max_episode_steps = MAX_EPISODE_STEPS_ANT 365 | elif 'swimmer' in args.env_name: 366 | self.max_episode_steps = MAX_EPISODE_STEPS_SWIMMER 367 | elif 'spaceship' in args.env_name: 368 | self.max_episode_steps = MAX_EPISODE_STEPS_SPACESHIP 369 | 370 | def reset(self): 371 | ''' 372 | Reset the environment to an environment that might have different dynamics. 373 | ''' 374 | return self.env1.reset() 375 | 376 | def reset_same(self): 377 | ''' 378 | Reset the environment to an environment with identical dynamics. 379 | ''' 380 | return self.env.reset(same=True) 381 | 382 | def sample_policy(self, policy_idx=0): 383 | ''' 384 | Sample a policy from your set of pretrained policies. 
def generic_step(self, policy_fn, state):
    '''
    Advance the environment by one step with a deterministic action
    from ``policy_fn``.

    The policy is cast to float and moved to ``args.device``, then queried
    through ``act`` with a zero recurrent state and a zero mask.

    Returns a dict with keys 'next_state', 'action', 'sa_tensor' (state
    and action concatenated), 'sas_tensor' (state, action and next state
    as a column; all zeros when the step ended the episode), 'reward' and
    'done'. Tensors are placed on ``args.device``.
    '''
    device = self.args.device
    hidden = torch.zeros(1, policy_fn.recurrent_hidden_state_size, device=device)
    mask = torch.zeros(1, 1, device=device)
    policy_fn = policy_fn.float().to(device)

    obs = state.squeeze().unsqueeze(0).to(torch.device(device)).float()
    policy_out = policy_fn.act(obs, hidden.float(), mask.float(), deterministic=True)
    action = policy_out[1].detach()

    action_tensor = action.float().reshape(self.action_dim)
    state_action_tensor = torch.cat(
        [state.float().squeeze().to(device), action_tensor.to(device)], dim=0)

    next_state, reward, done, _ = self.env.step(*action.cpu().numpy())
    next_state_tensor = torch.Tensor(next_state).float().to(device)

    if done:
        # Terminal transitions are represented by an all-zero column.
        sas_tensor = torch.FloatTensor([0] * self.enc_input_size) \
            .reshape(self.enc_input_size, 1)
    else:
        sas_tensor = torch.cat(
            (state_action_tensor, next_state_tensor), dim=0).unsqueeze(1)

    return {
        'next_state': next_state_tensor,
        'action': action_tensor,
        'sa_tensor': state_action_tensor,
        'sas_tensor': sas_tensor.to(device),
        'reward': reward,
        'done': done,
    }
420 | ''' 421 | src = [] 422 | this_src_batch = [] 423 | if env_idx is not None: 424 | state = self.env.reset(env_id=env_idx) 425 | else: 426 | state = self.reset_same() 427 | state = torch.tensor(state).float() 428 | done = False 429 | 430 | max_num_steps = self.inf_num_steps 431 | 432 | for t in range(max_num_steps): 433 | res = self.generic_step(policy_fn, state) 434 | state, done, sas_tensor = res['next_state'], res['done'], res['sas_tensor'] 435 | src.append(sas_tensor) 436 | src_pad = F.pad(torch.stack(src).reshape((t + 1) * self.enc_input_size), \ 437 | (0, self.env_max_seq_length - (t + 1) * len(sas_tensor))) 438 | 439 | this_src_batch.append(src_pad) 440 | if done: 441 | break 442 | 443 | res = this_src_batch[-1].reshape(max_num_steps, -1) 444 | return res 445 | 446 | def zeroshot_sample_src_from_pol_state(self, args, init_state, sizes, policy_idx=0, env_idx=0): 447 | ''' 448 | Sample transitions using a certain policy and starting in a given state. 449 | Works for Spaceship. 450 | ''' 451 | # get the policy embedding 452 | src_policy = [] 453 | masks_policy = [] 454 | 455 | episode_reward = 0 456 | state = init_state 457 | policy_fn = self.sample_policy(policy_idx=policy_idx) 458 | for t in range(args.max_num_steps): 459 | res = self.generic_step(policy_fn, state) 460 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \ 461 | res['sa_tensor'], res['sas_tensor'] 462 | episode_reward = args.gamma * episode_reward + reward 463 | src_policy.append(sa_tensor) 464 | masks_policy.append(torch.FloatTensor([done == False])) 465 | if done: 466 | break 467 | 468 | policy_feats = torch.stack(src_policy).unsqueeze(0) 469 | mask_policy = torch.stack(masks_policy).squeeze(1).unsqueeze(0).unsqueeze(0) 470 | 471 | if self.env_name.startswith('myacrobot'): 472 | episode_reward += args.max_num_steps 473 | 474 | # get the env embedding 475 | src_env = [] 476 | env_feats = [] 477 | 478 | state = init_state 479 | policy_fn = 
self.sample_policy(policy_idx=policy_idx) 480 | for t in range(self.inf_num_steps): 481 | res = self.generic_step(policy_fn, state) 482 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \ 483 | res['sa_tensor'], res['sas_tensor'] 484 | env_feats.append(sas_tensor) 485 | env_pad = F.pad(torch.stack(env_feats).reshape((t+1) * self.enc_input_size), \ 486 | (0, self.env_max_seq_length - (t+1) * len(sas_tensor))) 487 | if done: 488 | break 489 | 490 | source_env = torch.stack([env_pad.reshape(self.inf_num_steps,len(env_feats[0]))]) 491 | mask_env = (source_env != 0).unsqueeze(-2) 492 | mask_env = mask_env[:, :, :, 0].squeeze(2).unsqueeze(1) 493 | 494 | res = {'source_env': source_env, 495 | 'source_policy': policy_feats, 496 | 'mask_policy': mask_policy, 497 | 'mask_env': mask_env, 498 | 'episode_reward': episode_reward, 499 | 't': t, 500 | 'init_state': init_state} 501 | 502 | return res 503 | 504 | def zeroshot_sample_src_from_pol_state_mujoco(self, args, init_state, sizes, policy_idx=0, env_idx=0): 505 | ''' 506 | Sample transitions using a certain policy and starting in a given state. 507 | Works for Swimmer and Ant. 
508 | ''' 509 | # get the policy embedding 510 | src_policy = [] 511 | masks_policy = [] 512 | 513 | episode_reward = 0 514 | state = init_state 515 | policy_fn = self.sample_policy(policy_idx=policy_idx) 516 | for t in range(args.max_num_steps): 517 | res = self.generic_step(policy_fn, state) 518 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \ 519 | res['sa_tensor'], res['sas_tensor'] 520 | episode_reward = args.gamma * episode_reward + reward 521 | src_policy.append(sa_tensor) 522 | masks_policy.append(torch.FloatTensor([done == False])) 523 | if done: 524 | break 525 | if not done: 526 | for t in range(self.max_episode_steps - args.max_num_steps): 527 | res = self.generic_step(policy_fn, state) 528 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \ 529 | res['sa_tensor'], res['sas_tensor'] 530 | episode_reward = args.gamma * episode_reward + reward 531 | if done: 532 | break 533 | 534 | policy_feats = torch.stack(src_policy).unsqueeze(0) 535 | mask_policy = torch.stack(masks_policy).squeeze(1).unsqueeze(0).unsqueeze(0) 536 | 537 | # get the env embedding 538 | src_env = [] 539 | env_feats = [] 540 | 541 | state = init_state 542 | policy_fn = self.sample_policy(policy_idx=policy_idx) 543 | for t in range(self.inf_num_steps): 544 | res = self.generic_step(policy_fn, state) 545 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \ 546 | res['sa_tensor'], res['sas_tensor'] 547 | env_feats.append(sas_tensor) 548 | env_pad = F.pad(torch.stack(env_feats).reshape((t+1) * self.enc_input_size), \ 549 | (0, self.env_max_seq_length - (t+1) * len(sas_tensor))) 550 | if done: 551 | break 552 | 553 | source_env = torch.stack([env_pad.reshape(self.inf_num_steps, len(env_feats[0]))]) 554 | mask_env = (source_env != 0).unsqueeze(-2) 555 | mask_env = mask_env[:, :, :, 0].squeeze(2).unsqueeze(1) 556 | 557 | if args.norm_reward: 558 | episode_reward 
= (episode_reward - args.min_reward) / (args.max_reward - args.min_reward) 559 | 560 | res = {'source_env': source_env, 561 | 'source_policy': policy_feats, 562 | 'mask_policy': mask_policy, 563 | 'mask_env': mask_env, 564 | 'episode_reward': episode_reward, 565 | 't': t, 566 | 'init_state': init_state} 567 | 568 | return res 569 | 570 | def get_reward_pol_embedding_state(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False): 571 | ''' 572 | Estimate the return using Monte-Carlo for a given policy embedding starting at a given initial state. 573 | Works for Spaceship. 574 | ''' 575 | self.env.reset(env_id=env_idx) 576 | self.env.state = init_state 577 | state = init_obs 578 | 579 | device = args.device 580 | done = False 581 | episode_reward = 0 582 | for t in range(args.max_num_steps): 583 | recurrent_hidden_state = torch.zeros(policy_emb.shape[0], 584 | decoder.recurrent_hidden_state_size, device=device, requires_grad=True).float() 585 | mask_dec = torch.zeros(policy_emb.shape[0], 1, device=device, requires_grad=True).float() 586 | emb_state_input = torch.cat((policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device) 587 | action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec, 588 | deterministic=True)[1].squeeze().cpu().detach().numpy() 589 | next_state, reward, done, _ = self.env.step(action) 590 | 591 | episode_reward = args.gamma * episode_reward + reward 592 | state = torch.FloatTensor(next_state) 593 | if done: 594 | break 595 | return episode_reward, t 596 | 597 | def get_reward_pol_embedding_state_mujoco(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False): 598 | ''' 599 | Estimate the return using Monte-Carlo for a given policy embedding starting at a given initial state. 600 | Works for Swimmer and Ant. 
601 | ''' 602 | self.env.reset(env_id=env_idx) 603 | self.env.sim.set_state(init_state) 604 | state = init_obs 605 | 606 | device = args.device 607 | done = False 608 | episode_reward = 0 609 | t = 0 610 | for t in range(self.max_episode_steps): 611 | t += 1 612 | recurrent_hidden_state = torch.zeros(policy_emb.shape[0], 613 | decoder.recurrent_hidden_state_size, device=device, requires_grad=True).float() 614 | mask_dec = torch.zeros(policy_emb.shape[0], 1, device=device, requires_grad=True).float() 615 | emb_state_input = torch.cat((policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device) 616 | action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec, 617 | deterministic=True)[1].squeeze().cpu().detach().numpy() 618 | next_state, reward, done, _ = self.env.step(action) 619 | 620 | episode_reward = args.gamma * episode_reward + reward 621 | state = torch.FloatTensor(next_state) 622 | if done: 623 | break 624 | 625 | if args.norm_reward: 626 | episode_reward = (episode_reward - args.min_reward) / (args.max_reward - args.min_reward) 627 | 628 | return episode_reward, t 629 | 630 | def sample_policy_data(self, policy_idx=0, env_idx=None): 631 | ''' 632 | Sample transitions from a given policy in your collection. 
633 | ''' 634 | state_batch = [] 635 | tgt_batch = [] 636 | src_batch = [] 637 | mask_batch = [] 638 | mask_batch_all = [] 639 | 640 | if env_idx is not None: 641 | state = self.env.reset(env_id=env_idx) 642 | else: 643 | state = self.env.reset() 644 | 645 | trajectory = self.sample_policy(policy_idx=policy_idx) 646 | state = torch.tensor(state).unsqueeze(0) 647 | res = self.generic_step(trajectory, state) 648 | state, action, done, reward, sa_tensor = res['next_state'], res['action'], res['done'], res['reward'], \ 649 | res['sa_tensor'] 650 | sa_tensor = torch.transpose(sa_tensor, 1, 0) 651 | src = [] 652 | masks = [] 653 | 654 | for t in range(self.args.max_num_steps): 655 | state_batch.append(state) 656 | src.append(sa_tensor) 657 | tgt_batch.append(action.detach()) 658 | res = self.generic_step(trajectory, state) 659 | state, action, done, reward, sa_tensor = res['state'], res['action'], res['done'], res['reward'], \ 660 | res['sa_tensor'] 661 | 662 | sa_tensor = torch.transpose(sa_tensor, 1, 0) 663 | masks.append(torch.FloatTensor([done == False])) 664 | mask_batch.append(torch.FloatTensor([done == False])) 665 | 666 | for t in range(self.args.max_num_steps): 667 | src_tensor = torch.stack(src).squeeze(2) 668 | src_batch.append(src_tensor) 669 | mask_batch_all.append(torch.stack(masks)) 670 | 671 | return state_batch, tgt_batch, src_batch, mask_batch, mask_batch_all 672 | -------------------------------------------------------------------------------- /eval_pdvf.py: -------------------------------------------------------------------------------- 1 | import os, random, sys 2 | import numpy as np 3 | 4 | import torch 5 | import torch.optim as optim 6 | import torch.nn.functional as F 7 | 8 | from pdvf_networks import PDVF 9 | 10 | from pdvf_arguments import get_args 11 | 12 | from ppo.model import Policy 13 | from ppo.envs import make_vec_envs 14 | 15 | import env_utils 16 | import pdvf_utils 17 | import train_utils 18 | 19 | import myant 20 | import myswimmer 
def _evaluate_envs(label, env_ids, policy_ids, args, env_sampler, sizes, models, device):
    '''
    Roll out every policy in policy_ids on every environment in env_ids for
    args.num_eval_eps episodes and collect the return obtained after
    replacing the policy embedding with the top singular vector of the
    value function.

    label is the word used in the log lines ("Train" or "Eval").
    models is the tuple (policy_encoder, policy_decoder, env_encoder, value_net).
    Returns (rewards, unnorm_rewards): dicts mapping an environment id to
    the list of returns, as returned by the sampler and rescaled with
    args.min_reward / args.max_reward respectively.
    '''
    policy_encoder, policy_decoder, env_encoder, value_net = models
    is_mujoco = 'ant' in args.env_name or 'swimmer' in args.env_name
    if is_mujoco:
        sample_src = env_sampler.zeroshot_sample_src_from_pol_state_mujoco
        rollout = env_sampler.get_reward_pol_embedding_state_mujoco
    else:
        sample_src = env_sampler.zeroshot_sample_src_from_pol_state
        rollout = env_sampler.get_reward_pol_embedding_state

    rewards = {ei: [] for ei in env_ids}
    unnorm_rewards = {ei: [] for ei in env_ids}

    for ei in env_ids:
        for i in range(args.num_eval_eps):
            # Every episode index is re-seeded, so episode i uses the same
            # seed in every environment.
            args.seed = i
            np.random.seed(seed=i)
            torch.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

            for pi in policy_ids:
                init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
                # Remember the simulator state so later rollouts can restart from it.
                if is_mujoco:
                    init_state = env_sampler.env.sim.get_state()
                else:
                    init_state = env_sampler.env.state
                res = sample_src(args, init_obs, sizes, policy_idx=pi, env_idx=ei)

                source_env = res['source_env']
                mask_env = res['mask_env']
                policy_traj = res['source_policy']
                mask_policy = res['mask_policy']
                init_episode_reward = res['episode_reward']

                # Length-one sequences are duplicated before being encoded.
                if policy_traj.shape[1] == 1:
                    policy_traj = policy_traj.repeat(1, 2, 1)
                    mask_policy = mask_policy.repeat(1, 1, 2)
                emb_policy = policy_encoder(policy_traj.detach().to(device),
                    mask_policy.detach().to(device)).detach()
                if source_env.shape[1] == 1:
                    source_env = source_env.repeat(1, 2, 1)
                    mask_env = mask_env.repeat(1, 1, 2)
                emb_env = env_encoder(source_env.detach().to(device),
                    mask_env.detach().to(device)).detach()

                emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
                emb_env = F.normalize(emb_env, p=2, dim=1).detach()

                pred_value = value_net(init_obs.unsqueeze(0).to(device),
                    emb_env.to(device), emb_policy.to(device)).item()
                decoded_reward = rollout(args, init_state, init_obs,
                    emb_policy, policy_decoder, env_idx=ei)[0]

                qf = value_net.get_qf(init_obs.unsqueeze(0).to(device), emb_env)
                u, s, v = torch.svd(qf.squeeze())

                # A singular vector is only defined up to sign, so both
                # directions are rolled out and the better return is kept.
                opt_policy_pos = u[:, 0].unsqueeze(0)
                opt_policy_neg = -u[:, 0].unsqueeze(0)
                episode_reward_pos, _ = rollout(args, init_state, init_obs,
                    opt_policy_pos, policy_decoder, env_idx=ei)
                episode_reward_neg, _ = rollout(args, init_state, init_obs,
                    opt_policy_neg, policy_decoder, env_idx=ei)
                if episode_reward_pos >= episode_reward_neg:
                    episode_reward = episode_reward_pos
                else:
                    episode_reward = episode_reward_neg

                unnorm_episode_reward = episode_reward * (args.max_reward - args.min_reward) + args.min_reward
                unnorm_init_episode_reward = init_episode_reward * (args.max_reward - args.min_reward) + args.min_reward
                unnorm_decoded_reward = decoded_reward * (args.max_reward - args.min_reward) + args.min_reward

                unnorm_rewards[ei].append(unnorm_episode_reward)
                rewards[ei].append(episode_reward)
                if i % args.log_interval == 0:
                    if is_mujoco:
                        print(f"\n{label} Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}")
                        print(f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}")
                    print(f"{label} Environemnt: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}")
                    print(f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}")

    return rewards, unnorm_rewards


def _print_summary(label, env_ids, args, all_mean_rewards, all_mean_unnorm_rewards):
    '''
    Print mean and std of the recorded mean rewards for each environment in
    env_ids (rescaled rewards for Ant / Swimmer, raw ones otherwise).
    '''
    if 'ant' in args.env_name or 'swimmer' in args.env_name:
        source = all_mean_unnorm_rewards
    else:
        source = all_mean_rewards
    for ei in env_ids:
        print("{} Env {} has reward with mean {:.3f} and std {:.3f}"\
            .format(label, ei, np.mean(source[ei]), np.std(source[ei])))


def eval_pdvf():
    '''
    Evaluate the Policy-Dynamics Value Function.

    Loads the pretrained PPO policies, the policy / dynamics embedding
    models and the value function, then reports the returns obtained on the
    first 3/4 of the environments ("train") and on the remaining ones
    ("eval"), always using the first 3/4 of the policy indices.
    '''
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.set_num_threads(1)
    device = args.device
    if device != 'cpu':
        torch.cuda.empty_cache()

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    env = make_vec_envs(args, device)
    try:
        env.reset()

        # One pretrained PPO policy per (environment, seed) pair.
        base_policies = []
        for e in range(args.num_envs):
            for s in range(args.num_seeds):
                name = 'ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s)
                actor_critic = Policy(
                    env.observation_space.shape,
                    env.action_space,
                    base_kwargs={'recurrent': False})
                actor_critic.to(device)
                # map_location lets checkpoints saved on another device load here.
                actor_critic.load_state_dict(
                    torch.load(os.path.join(args.save_dir, name), map_location=device))
                base_policies.append(actor_critic)

        # Load the pretrained policy encoder / decoder and dynamics encoder
        policy_encoder, policy_decoder = pdvf_utils.load_policy_model(
            args, env)
        env_encoder = pdvf_utils.load_dynamics_model(
            args, env)

        value_net = PDVF(env.observation_space.shape[0], args.dynamics_embedding_dim, args.hidden_dim_pdvf,
            args.policy_embedding_dim, device=device).to(device)
        path_to_pdvf = os.path.join(args.save_dir_pdvf, \
            "pdvf-stage{}.{}.pt".format(args.stage, args.env_name))
        value_net.load_state_dict(
            torch.load(path_to_pdvf, map_location=device)['state_dict'])
        value_net.eval()

        num_train = int(3/4*args.num_envs)
        train_policies = list(range(num_train))
        train_envs = list(range(num_train))
        eval_envs = list(range(num_train, args.num_envs))

        env_enc_input_size = env.observation_space.shape[0] + env.action_space.shape[0]
        sizes = pdvf_utils.DotDict({'state_dim': env.observation_space.shape[0], \
            'action_dim': env.action_space.shape[0], 'env_enc_input_size': \
            env_enc_input_size, 'env_max_seq_length': args.max_num_steps * env_enc_input_size})

        env_sampler = env_utils.EnvSamplerPDVF(env, base_policies, args)
        models = (policy_encoder, policy_decoder, env_encoder, value_net)

        all_mean_rewards = [[] for _ in range(args.num_envs)]
        all_mean_unnorm_rewards = [[] for _ in range(args.num_envs)]

        def run_split(label, env_ids):
            rewards, unnorm_rewards = _evaluate_envs(
                label, env_ids, train_policies, args, env_sampler, sizes, models, device)
            # NOTE(review): one mean is recorded per environment, so the std
            # printed in the summary is taken over a single value -- confirm
            # whether per-episode aggregation was intended.
            for ei in env_ids:
                all_mean_rewards[ei].append(np.mean(rewards[ei]))
                all_mean_unnorm_rewards[ei].append(np.mean(unnorm_rewards[ei]))

        # Eval on Train Envs
        run_split('Train', train_envs)
        _print_summary('Train', train_envs, args, all_mean_rewards, all_mean_unnorm_rewards)

        # Eval on Eval Envs
        run_split('Eval', eval_envs)
        _print_summary('Train', train_envs, args, all_mean_rewards, all_mean_unnorm_rewards)
        _print_summary('Eval', eval_envs, args, all_mean_rewards, all_mean_unnorm_rewards)
    finally:
        # Release the environments even when evaluation fails part-way.
        env.close()


if __name__ == '__main__':
    eval_pdvf()
@contextlib.contextmanager
def make_temp_directory(prefix=''):
    '''
    Context manager that creates a temporary directory whose name starts
    with prefix, yields its path, and removes the directory together with
    its contents on exit, including when the body raises.
    '''
    # mkdtemp's first positional parameter is ``suffix``; the prefix has to
    # be passed by keyword to end up at the start of the directory name.
    temp_dir = mkdtemp(prefix=prefix)
    try:
        yield temp_dir
    finally:
        rmtree(temp_dir)
28 | ''' 29 | def __init__(self, default_ind=0, num_envs=20, radius=4.0, viscosity=0.05, basepath=None): 30 | self.num_envs = num_envs 31 | AntEnvOrig.__init__(self) 32 | 33 | self.default_params = {'limbs': [.2, .2, .2, .2], 'wind': [0, 0, 0], 'viscosity': 0.0} 34 | self.default_ind = default_ind 35 | 36 | self.env_configs = [] 37 | 38 | for i in range(num_envs): 39 | angle = i * (1/num_envs) * (2*np.pi) 40 | wind_x = radius * np.cos(angle) 41 | wind_y = radius * np.sin(angle) 42 | self.env_configs.append( 43 | {'limbs': [.2, .2, .2, .2], 'wind': [wind_x, wind_y, 0], 'viscosity': viscosity} 44 | ) 45 | 46 | self.env_configs.append(self.default_params) 47 | 48 | self.angle = (self.default_ind + 1) * (1/num_envs) * (2*np.pi) 49 | self.wind_x = radius * np.cos(self.angle) 50 | self.wind_y = radius * np.sin(self.angle) 51 | 52 | self.basepath = basepath 53 | file = open(os.path.join(self.basepath, "ant.xml")) 54 | self.xml = file.readlines() 55 | file.close() 56 | 57 | 58 | def get_xml(self, ind=0): 59 | xmlfile = open(os.path.join(self.basepath, "ant.xml")) 60 | tmp = xmlfile.read() 61 | xmlfile.close() 62 | if ind is None: 63 | ind = self.default_ind 64 | 65 | params = {} 66 | params.update(self.default_params) 67 | params.update(self.env_configs[ind]) 68 | limb_params = params['limbs'] 69 | wind_params = params['wind'] 70 | viscosity = params['viscosity'] 71 | wx, wy, wz = wind_params 72 | 73 | tmp = re.sub("