├── README.md
├── embedding_networks.py
├── env_utils.py
├── eval_pdvf.py
├── figures
│   ├── ablations_pdvf.png
│   ├── pdvf_gif.png
│   └── results_pdvf.png
├── myant
│   ├── myant
│   │   ├── __init__.py
│   │   ├── envs
│   │   │   ├── __init__.py
│   │   │   └── myant.py
│   │   └── utils.py
│   └── setup.py
├── myspaceship
│   ├── myspaceship
│   │   ├── __init__.py
│   │   └── envs
│   │       ├── __init__.py
│   │       └── myspaceship.py
│   └── setup.py
├── myswimmer
│   ├── myswimmer
│   │   ├── __init__.py
│   │   ├── envs
│   │   │   ├── __init__.py
│   │   │   └── myswimmer.py
│   │   └── utils.py
│   └── setup.py
├── pdvf_arguments.py
├── pdvf_networks.py
├── pdvf_storage.py
├── pdvf_utils.py
├── ppo
│   ├── __init__.py
│   ├── algo
│   │   ├── __init__.py
│   │   └── ppo.py
│   ├── arguments.py
│   ├── distributions.py
│   ├── envs.py
│   ├── evaluation.py
│   ├── model.py
│   ├── ppo_main.py
│   ├── storage.py
│   └── utils.py
├── requirements.txt
├── train_dynamics_embedding.py
├── train_pdvf.py
├── train_policy_embedding.py
└── train_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Policy-Dynamics Value Functions (PD-VF)
2 |
3 | This is the source code for the paper
4 |
5 | [Fast Adaptation to New Environments via Policy-Dynamics Value Functions](https://arxiv.org/pdf/2007.02879)
6 |
7 | by Roberta Raileanu, Max Goldstein, Arthur Szlam, and Rob Fergus,
8 |
9 | accepted at ICML 2020.
10 |
11 |
12 |
13 | ## Citation
14 | If you use this code in your own work, please cite our paper:
15 | ```
16 | @incollection{icml2020_3993,
17 | abstract = {Standard RL algorithms assume fixed environment dynamics and require a significant amount of interaction to adapt to new environments. We introduce Policy-Dynamics Value Functions (PD-VF), a novel approach for rapidly adapting to dynamics different from those previously seen in training. PD-VF explicitly estimates the cumulative reward in a space of policies and environments. An ensemble of conventional RL policies is used to gather experience on training environments, from which embeddings of both policies and environments can be learned. Then, a value function conditioned on both embeddings is trained. At test time, a few actions are sufficient to infer the environment embedding, enabling a policy to be selected by maximizing the learned value function (which requires no additional environment interaction). We show that our method can rapidly adapt to new dynamics on a set of MuJoCo domains. },
18 | author = {Raileanu, Roberta and Goldstein, Max and Szlam, Arthur and Rob Fergus, Facebook},
19 | booktitle = {Proceedings of Machine Learning and Systems 2020},
20 | pages = {7078--7089},
21 | title = {Fast Adaptation to New Environments via Policy-Dynamics Value Functions},
22 | year = {2020}
23 | }
24 | ```
25 |
26 | ## Requirements
27 | ```
28 | conda create -n pdvf python=3.7
29 | conda activate pdvf
30 |
31 | git clone git@github.com:rraileanu/policy-dynamics-value-functions.git
32 | cd policy-dynamics-value-functions
33 | pip install -r requirements.txt
34 |
35 | cd myant
36 | pip install -e .
37 |
38 | cd ../myswimmer
39 | pip install -e .
40 |
41 | cd ../myspaceship
42 | pip install -e .
43 | ```
44 |
45 | ## (1) Reinforcement Learning Phase
46 |
47 | Train PPO policies on each environment, one for each seed.
48 |
49 | Each of the commands below needs to be run
50 | for each seed in [0,...,4] and each default-ind in [0,...,19].
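
To run the full sweep, the commands can be scripted. Below is a minimal sketch in Python (an assumption, not part of the repo), using sequential runs and `spaceship-v0` as the example environment; substitute the other environment names as needed.

```python
# Hypothetical sweep script: trains one PPO policy per (default-ind, seed) pair.
import itertools
import subprocess

for default_ind, seed in itertools.product(range(20), range(5)):
    subprocess.run(
        ["python", "ppo/ppo_main.py",
         "--env-name", "spaceship-v0",
         "--default-ind", str(default_ind),
         "--seed", str(seed)],
        check=True,
    )
```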
51 |
52 | ### Spaceship
53 | ```
54 | python ppo/ppo_main.py \
55 | --env-name spaceship-v0 --default-ind 0 --seed 0
56 | ```
57 |
58 | ### Swimmer
59 | ```
60 | python ppo/ppo_main.py \
61 | --env-name myswimmer-v0 --default-ind 0 --seed 0
62 | ```
63 |
64 | ### Ant-wind
65 | ```
66 | python ppo/ppo_main.py \
67 | --env-name myant-v0 --default-ind 0 --seed 0
68 | ```
69 |
70 | ## (2) Self-Supervised Learning Phase
71 |
72 | ## Dynamics Embedding
73 |
74 | ### Spaceship
75 | ```
76 | python train_dynamics_embedding.py \
77 | --env-name spaceship-v0 \
78 | --dynamics-embedding-dim 8 --dynamics-batch-size 8 \
79 | --inf-num-steps 1 --num-dec-traj 10 \
80 | --save-dir-dynamics-embedding ./models/dynamics-embeddings
81 | ```
82 |
83 | ### Swimmer
84 | ```
85 | python train_dynamics_embedding.py \
86 | --env-name myswimmer-v0 \
87 | --dynamics-embedding-dim 2 --dynamics-batch-size 32 \
88 | --inf-num-steps 1 --num-dec-traj 10 \
89 | --save-dir-dynamics-embedding ./models/dynamics-embeddings
90 | ```
91 |
92 | ### Ant-wind
93 | ```
94 | python train_dynamics_embedding.py \
95 | --env-name myant-v0 \
96 | --dynamics-embedding-dim 8 --dynamics-batch-size 32 \
97 | --inf-num-steps 2 --num-dec-traj 10 \
98 | --save-dir-dynamics-embedding ./models/dynamics-embeddings
99 | ```
100 |
101 | ## Policy Embedding
102 |
103 | ### Spaceship
104 | ```
105 | python train_policy_embedding.py \
106 | --env-name spaceship-v0 --num-dec-traj 1 \
107 | --save-dir-policy-embedding ./models/policy-embeddings
108 | ```
109 |
110 | ### Swimmer
111 | ```
112 | python train_policy_embedding.py \
113 | --env-name myswimmer-v0 --num-dec-traj 1 \
114 | --save-dir-policy-embedding ./models/policy-embeddings
115 | ```
116 |
117 | ### Ant-wind
118 | ```
119 | python train_policy_embedding.py \
120 | --env-name myant-v0 --num-dec-traj 1 \
121 | --save-dir-policy-embedding ./models/policy-embeddings
122 | ```
123 |
124 | ## (3) Supervised Learning Phase
125 |
126 | ### Spaceship
127 | ```
128 | python train_pdvf.py \
129 | --env-name spaceship-v0 \
130 | --dynamics-batch-size 8 --policy-batch-size 2048 \
131 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \
132 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \
133 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \
134 | --save-dir-policy-embedding ./models/policy-embeddings \
135 | --save-dir-pdvf ./models/pdvf-models
136 | ```
137 |
138 | ### Swimmer
139 | ```
140 | python train_pdvf.py \
141 | --env-name myswimmer-v0 \
142 | --dynamics-batch-size 8 --policy-batch-size 2048 \
143 | --dynamics-embedding-dim 2 --policy-embedding-dim 8 \
144 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \
145 | --norm-reward --min-reward -60 --max-reward 200 \
146 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \
147 | --save-dir-policy-embedding ./models/policy-embeddings \
148 | --save-dir-pdvf ./models/pdvf-models
149 | ```
150 |
151 | ### Ant-wind
152 | ```
153 | python train_pdvf.py \
154 | --env-name myant-v0 \
155 | --dynamics-batch-size 32 --policy-batch-size 2048 \
156 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \
157 | --num-dec-traj 10 --inf-num-steps 2 --log-interval 10 \
158 | --norm-reward --min-reward -400 --max-reward 1000 \
159 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \
160 | --save-dir-policy-embedding ./models/policy-embeddings \
161 | --save-dir-pdvf ./models/pdvf-models
162 | ```
163 |
164 | ## (4) Evaluation Phase
165 |
166 | ### Spaceship
167 | ```
168 | python eval_pdvf.py \
169 | --env-name spaceship-v0 --stage 20 \
170 | --dynamics-batch-size 8 --policy-batch-size 2048 \
171 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \
172 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \
173 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \
174 | --save-dir-policy-embedding ./models/policy-embeddings \
175 | --save-dir-pdvf ./models/pdvf-models
176 | ```
177 |
178 | ### Swimmer
179 | ```
180 | python eval_pdvf.py \
181 | --env-name myswimmer-v0 --stage 20 \
182 | --dynamics-batch-size 8 --policy-batch-size 2048 \
183 | --dynamics-embedding-dim 2 --policy-embedding-dim 8 \
184 | --num-dec-traj 10 --inf-num-steps 1 --log-interval 10 \
185 | --norm-reward --min-reward -60 --max-reward 200 \
186 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \
187 | --save-dir-policy-embedding ./models/policy-embeddings \
188 | --save-dir-pdvf ./models/pdvf-models
189 | ```
190 |
191 | ### Ant-wind
192 | ```
193 | python eval_pdvf.py \
194 | --env-name myant-v0 --stage 20 \
195 | --dynamics-batch-size 32 --policy-batch-size 2048 \
196 | --dynamics-embedding-dim 8 --policy-embedding-dim 8 \
197 | --num-dec-traj 10 --inf-num-steps 2 --log-interval 10 \
198 | --norm-reward --min-reward -400 --max-reward 1000 \
199 | --save-dir-dynamics-embedding ./models/dynamics-embeddings \
200 | --save-dir-policy-embedding ./models/policy-embeddings \
201 | --save-dir-pdvf ./models/pdvf-models
202 | ```
203 |
204 | ## Results
205 | 
206 |
207 | 
208 |
209 |
210 |
--------------------------------------------------------------------------------
/embedding_networks.py:
--------------------------------------------------------------------------------
1 | """
2 | Models for the policy and dynamics embeddings / autoencoders.
3 |
4 | The encoder is a transformer and the code is built on top
5 | of the following open-source implementation:
6 | https://github.com/jadore801120/attention-is-all-you-need-pytorch/
7 | """
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import math, copy
12 | from ppo.utils import init
13 |
14 |
15 | def clones(module, N):
16 | "Produce N identical layers."
17 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
18 |
19 |
20 | def init_weights(model):
21 | for p in model.parameters():
22 | if p.dim() > 1:
23 | nn.init.xavier_uniform_(p)
24 |
25 |
26 | def attention(query, key, value, mask=None, dropout=None):
27 | "Compute 'Scaled Dot Product Attention'"
28 | d_k = query.size(-1)
29 | scores = torch.matmul(query, key.transpose(-2, -1)) \
30 | / math.sqrt(d_k)
31 | if mask is not None:
32 | scores = scores.masked_fill(mask == 0, -1e9)
33 | p_attn = F.softmax(scores, dim = -1)
34 | if dropout is not None:
35 | p_attn = dropout(p_attn)
36 | return torch.matmul(p_attn, value), p_attn
37 |
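# Shape note (added comment, not in the original file): with query/key/value of
# shape (batch, heads, seq_len, d_k), `attention` returns an output of the same
# shape together with attention weights of shape (batch, heads, seq_len, seq_len).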
38 |
39 | def make_encoder_oh(input_size, N=1, d_model=320, h=5, dropout=0.1, \
40 | d_emb=8, use_extra_fc=True, no_norm=False):
41 | "Helper: Construct a model from hyperparameters."
42 | c = copy.deepcopy
43 | attn = MultiHeadedAttention(h, d_model)
44 | model = EmbedEncoder(
45 | Encoder(EncoderLayer(d_model, c(attn), dropout, no_norm=no_norm), \
46 | d_model=d_model, d_emb=d_emb, use_extra_fc=use_extra_fc, no_norm=no_norm
47 | ),
48 | LinearEmb(d_model, input_size)
49 | )
50 |
51 | for p in model.parameters():
52 | if p.dim() > 1:
53 | nn.init.xavier_uniform_(p)
54 | return model
55 |
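# Usage sketch (assumed shapes and names, not part of the original file):
#   encoder = make_encoder_oh(input_size=sas_dim, d_model=320, h=5, d_emb=8)
#   src  : (batch, seq_len, sas_dim) stacked transition features
#   mask : (batch, 1, seq_len) boolean padding mask
#   emb = encoder(src, mask)  # -> (batch, d_emb) when use_extra_fc=True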
56 |
57 | class EmbedEncoder(nn.Module):
58 | """
59 | An encoder-only model: a linear input embedding followed by a
60 | transformer encoder, used to produce policy and dynamics embeddings.
61 | """
62 | def __init__(self, encoder, src_embed):
63 | super(EmbedEncoder, self).__init__()
64 | self.encoder = encoder
65 | self.src_embed = src_embed
66 |
67 | def forward(self, src, src_mask):
68 | return self.encoder(self.src_embed(src), src_mask)
69 |
70 |
71 | class LinearEmb(nn.Module):
72 | def __init__(self, d_model, input_size):
73 | super(LinearEmb, self).__init__()
74 | self.lin_emb = nn.Linear(input_size, d_model)
75 |
76 | def forward(self, x):
77 | return self.lin_emb(x.float())
78 |
79 |
80 | class Encoder(nn.Module):
81 | "Core encoder is a stack of N layers"
82 | def __init__(self, layer, d_model=320, d_emb=32, use_extra_fc=False, no_norm=False):
83 | super(Encoder, self).__init__()
84 | self.layer = layer
85 | self.norm = LayerNorm(layer.size)
86 | self.use_extra_fc = use_extra_fc
87 | self.no_norm = no_norm
88 | if self.use_extra_fc:
89 | self.fc_out = nn.Linear(d_model, d_emb)
90 |
91 | def forward(self, x, mask, use_extra_fc=False):
92 | "Pass the input (and mask) through each layer in turn."
93 | x = self.layer(x, mask)
94 | if not self.no_norm:
95 | x = self.norm(x)
96 | if self.use_extra_fc:
97 | x = self.fc_out(x.squeeze(1).mean(1))
98 | return x
99 |
100 |
101 | class LayerNorm(nn.Module):
102 | "Construct a layernorm module (See citation for details)."
103 | def __init__(self, features, eps=1e-6):
104 | super(LayerNorm, self).__init__()
105 | self.a_2 = nn.Parameter(torch.ones(features))
106 | self.b_2 = nn.Parameter(torch.zeros(features))
107 | self.eps = eps
108 |
109 | def forward(self, x):
110 | mean = x.mean(-1, keepdim=True)
111 | std = x.std(-1, keepdim=True)
112 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
113 |
114 |
115 | class SublayerConnection(nn.Module):
116 | """
117 | A pre-norm sublayer wrapper: layer norm, then the sublayer, then dropout.
118 | Note that, unlike the original Transformer, no residual connection is added here.
119 | """
120 | def __init__(self, size, dropout):
121 | super(SublayerConnection, self).__init__()
122 | self.norm = LayerNorm(size)
123 | self.dropout = nn.Dropout(dropout)
124 |
125 | def forward(self, x, sublayer):
126 | "Apply the norm, the sublayer, and dropout to the input (no residual connection)."
127 | return self.dropout(sublayer(self.norm(x)))
128 |
129 |
130 | class EncoderLayer(nn.Module):
131 | "Encoder is made up of self-attn and feed forward (defined below)"
132 | def __init__(self, size, self_attn, dropout, no_norm=False):
133 | super(EncoderLayer, self).__init__()
134 | self.self_attn = self_attn
135 | self.sublayer = SublayerConnection(size, dropout)
136 | self.size = size
137 | self.no_norm = no_norm
138 |
139 | def forward(self, x, mask):
140 | "Follow Figure 1 (left) for connections."
141 | if self.no_norm:
142 | x = self.self_attn(x, x, x, mask)
143 | else:
144 | x = self.sublayer(x, lambda x: self.self_attn(x, x, x, mask))
145 | return x
146 |
147 |
148 | class MultiHeadedAttention(nn.Module):
149 | def __init__(self, h, d_model, dropout=0.1):
150 | "Take in model size and number of heads."
151 | super(MultiHeadedAttention, self).__init__()
152 | assert d_model % h == 0
153 | # We assume d_v always equals d_k
154 | self.d_k = d_model // h
155 | self.h = h
156 | self.linears = clones(nn.Linear(d_model, d_model), 4)
157 | self.attn = None
158 | self.dropout = nn.Dropout(p=dropout)
159 |
160 | def forward(self, query, key, value, mask=None):
161 | "Implements Figure 2"
162 | if mask is not None:
163 | # Same mask applied to all h heads.
164 | mask = mask.unsqueeze(1)
165 | nbatches = query.size(0)
166 |
167 | # 1) Do all the linear projections in batch from d_model => h x d_k
168 | query, key, value = \
169 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
170 | for l, x in zip(self.linears, (query, key, value))]
171 |
172 | # 2) Apply attention on all the projected vectors in batch.
173 | x, self.attn = attention(query, key, value, mask=mask,
174 | dropout=self.dropout)
175 |
176 | # 3) "Concat" using a view and apply a final linear.
177 | x = x.transpose(1, 2).contiguous() \
178 | .view(nbatches, -1, self.h * self.d_k)
179 |
180 | return self.linears[-1](x)
181 |
182 |
183 | class DecoderSpaceship(nn.Module):
184 | def __init__(self, input_size, hidden_size, output_size, device='cuda'):
185 | super(DecoderSpaceship, self).__init__()
186 |
187 | self.device = device
188 | self.input_size = input_size
189 | self.hidden_size = hidden_size
190 | self.output_size = output_size
191 |
192 | self.fc_in = nn.Linear(input_size, hidden_size)
193 | self.fc_mid = nn.Linear(hidden_size, hidden_size)
194 |
195 | self.fc_out_x = nn.Linear(hidden_size, output_size)
196 | self.fc_out_y = nn.Linear(hidden_size, output_size)
197 |
198 | def forward(self, input, time=None):
199 | input = torch.cat((input, time), dim=1)
200 | output = F.relu(self.fc_in(input))
201 | output = F.relu(self.fc_mid(output))
202 |
203 | params_x = self.fc_out_x(output)
204 | params_y = self.fc_out_y(output)
205 |
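# Added note: each output head is read as (mean, scale) per action dimension, and an
# action is sampled via the reparameterization mean + scale * eps with eps ~ N(0, 1).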
206 | shape = params_x.shape[0]
207 | act_x = torch.randn(shape).to(self.device)*params_x[:, 1] + params_x[:, 0]
208 | act_y = torch.randn(shape).to(self.device)*params_y[:, 1] + params_y[:, 0]
209 |
210 | act = torch.cat((act_x.unsqueeze(1), act_y.unsqueeze(1)), dim=1)
211 |
212 | return act
213 |
214 |
215 | class DecoderMujoco(nn.Module):
216 | def __init__(self, input_size, hidden_size, output_size, device='cuda'):
217 | super(DecoderMujoco, self).__init__()
218 |
219 | self.device = device
220 | self.input_size = input_size
221 | self.hidden_size = hidden_size
222 | self.output_size = output_size
223 |
224 | self.fc_in = nn.Linear(input_size, hidden_size)
225 | self.fc_mid = nn.Linear(hidden_size, hidden_size)
226 |
227 | self.fc_out = nn.Linear(hidden_size, output_size)
228 |
229 | def forward(self, input, time=None):
230 | input = torch.cat((input, time), dim=1)
231 | output = F.relu(self.fc_in(input))
232 | output = F.relu(self.fc_mid(output))
233 | next_state = self.fc_out(output)
234 | return next_state
235 |
236 |
237 |
--------------------------------------------------------------------------------
/env_utils.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import myspaceship
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | MAX_EPISODE_STEPS_ANT = 256
8 | MAX_EPISODE_STEPS_SWIMMER = 1000
9 | MAX_EPISODE_STEPS_SPACESHIP = 50
10 |
11 |
12 | def make_one_hot(x, nb_digits=3, batch_size=1):
13 | '''
14 | Convert int to one hot tensor
15 | '''
16 | y = x.reshape(batch_size, 1)
17 |
18 | # One hot encoding buffer that you create out of the loop and just keep reusing
19 | y_onehot = torch.FloatTensor(batch_size, nb_digits)
20 |
21 | # In your for loop
22 | y_onehot.zero_()
23 | y_onehot.scatter_(1, y.long(), 1)
24 |
25 | return y_onehot
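# Usage sketch (added comment, not in the original file):
#   make_one_hot(torch.tensor([2]), nb_digits=3, batch_size=1)
#   -> tensor([[0., 0., 1.]])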
26 |
27 |
28 | class EnvSamplerEmb():
29 | '''
30 | Environment sampler object for training the
31 | policy and dynamics embeddings.
32 | '''
33 | def __init__(self, env, base_policy, args):
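# Added note: unwrap the vectorized env down to the raw custom environment;
# the number of wrapper layers differs per task, hence the cases below.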
34 | if args.env_name.startswith('spaceship'):
35 | self.env = env.venv.envs[0]
36 | elif args.env_name.startswith('myswimmer'):
37 | self.env = env.venv.envs[0].env.env
38 | else:
39 | self.env = env.venv.envs[0].env.env.env
40 | self.env.reset()
41 | self.env1 = env
42 | self.base_policy = base_policy
43 | self.args = args
44 | self.env_name = args.env_name
45 | self.action_dim = self.env1.action_space.shape[0]
46 | self.state_dim = self.env1.observation_space.shape[0]
47 | self.enc_input_size = 2*self.state_dim + self.action_dim
48 | self.inf_num_steps = args.inf_num_steps
49 | self.env_max_seq_length = args.inf_num_steps * self.enc_input_size
50 | self.max_seq_length = args.max_num_steps * self.enc_input_size
51 |
52 | if 'ant' in args.env_name:
53 | self.max_episode_steps = MAX_EPISODE_STEPS_ANT
54 | elif 'swimmer' in args.env_name:
55 | self.max_episode_steps = MAX_EPISODE_STEPS_SWIMMER
56 | elif 'spaceship' in args.env_name:
57 | self.max_episode_steps = MAX_EPISODE_STEPS_SPACESHIP
58 |
59 | def reset(self):
60 | '''
61 | Reset the environment to an environment that might have different dynamics.
62 | '''
63 | self.env1.reset()
64 |
65 | def reset_same(self):
66 | '''
67 | Reset the environment to an environment with identical dynamics.
68 | '''
69 | return self.env.reset(same=True)
70 |
71 | def sample_env_context(self, policy_fn, env_idx=None):
72 | '''
73 | Generate a few steps in the new environment.
74 | Use these transitions to infer the environment's dynamics.
75 | '''
76 | src = []
77 | this_src_batch = []
78 | if env_idx is not None:
79 | state = self.env.reset(env_id=env_idx)
80 | else:
81 | state = self.reset_same()
82 | state = torch.tensor(state).float()
83 | done = False
84 |
85 | max_num_steps = self.inf_num_steps
86 |
87 | for t in range(max_num_steps):
88 | recurrent_hidden_state = torch.zeros(
89 | 1, policy_fn.recurrent_hidden_state_size, device=self.args.device)
90 | mask = torch.zeros(1, 1, device=self.args.device)
91 | policy_fn = policy_fn.float()
92 | action = policy_fn.act(
93 | state.squeeze().unsqueeze(0).to(torch.device(self.args.device)).float(),
94 | recurrent_hidden_state.float(), mask.float(),
95 | deterministic=True)[1].detach()
96 |
97 | action_tensor = action.float().reshape(self.action_dim)
98 | state_action_tensor = torch.cat((state.squeeze().to(device=self.args.device),
99 | action_tensor.reshape(self.action_dim)), dim=0)
100 | if not done:
101 | next_state, _, done, _ = self.env.step(*action.cpu().numpy())
102 | next_state_tensor = torch.FloatTensor(next_state)
103 | state = next_state_tensor
104 | state_action_state_tensor = torch.cat((state_action_tensor.to(device=self.args.device), \
105 | next_state_tensor.to(device=self.args.device)), dim=0).unsqueeze(1)
106 | else:
107 | state_action_state_tensor = torch.FloatTensor([0 for _ in range(self.enc_input_size)])\
108 | .reshape(self.enc_input_size, 1)
109 | src.append(state_action_state_tensor)
110 | src_pad = F.pad(torch.stack(src).reshape((t + 1) * self.enc_input_size), \
111 | (0, self.max_seq_length - (t + 1) * len(state_action_state_tensor)))
112 |
113 | this_src_batch.append(src_pad)
114 |
115 | return this_src_batch[-1]
116 |
117 | def sample_policy(self, policy_idx=0):
118 | '''
119 | Sample a policy from your set of pretrained policies.
120 | '''
121 | return self.base_policy[policy_idx]
122 |
123 | def generic_step(self, policy_fn, state):
124 | '''
125 | Take a step in the environment
126 | '''
127 | action = policy_fn(state.squeeze().unsqueeze(0))
128 | action_tensor = torch.FloatTensor(action)
129 | state_action_tensor = torch.cat((state.squeeze(),
130 | action_tensor.squeeze()), dim=0)
131 |
132 | next_state, reward, done, _ = self.env.step(action.numpy())
133 | next_state_tensor = torch.Tensor(next_state).float()
134 | if not done:
135 | sas_tensor = torch.cat((state_action_tensor, next_state_tensor), dim=0).unsqueeze(1)
136 | else:
137 | sas_tensor = torch.FloatTensor([0 for _ in range(self.enc_input_size)]) \
138 | .reshape(self.enc_input_size, 1)
139 |
140 | res = {'state': next_state_tensor, 'action': action_tensor, 'sa_tensor': state_action_tensor,
141 | 'sas_tensor': sas_tensor, 'reward': reward, 'done': done}
142 | return res
143 |
144 | def get_decoded_traj(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False):
145 | '''
146 | Decode a trajectory using the policy decoder conditioned on a given policy embedding,
147 | in a given environment for a given initial state. Works with the Spaceship environment.
148 | '''
149 | state = self.env.reset(env_id=env_idx)
150 | state = torch.FloatTensor(state)
151 |
152 | device = args.device
153 | done = False
154 | episode_reward = 0
155 |
156 | all_emb_state = []
157 | all_recurrent_state = []
158 | all_mask = []
159 | all_action = []
160 | for t in range(args.max_num_steps):
161 | recurrent_hidden_state = torch.zeros(policy_emb.shape[0],
162 | decoder.recurrent_hidden_state_size, device=device, requires_grad=True).float()
163 | mask_dec = torch.zeros(policy_emb.shape[0], 1, device=device, requires_grad=True).float()
164 | emb_state_input = torch.cat((policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device)
165 |
166 | action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec,
167 | deterministic=True)[1]
168 |
169 | action_flat = action.squeeze().cpu().detach().numpy()
170 | action = action.cpu().detach().numpy()
171 |
172 | all_emb_state.append(emb_state_input)
173 | all_recurrent_state.append(recurrent_hidden_state)
174 | all_mask.append(mask_dec)
175 | all_action.append(torch.tensor(action).to(device))
176 |
177 | next_state, reward, done, _ = self.env.step(action_flat)
178 | episode_reward = args.gamma * episode_reward + reward
179 | state = torch.FloatTensor(next_state)
180 | if done:
181 | break
182 |
183 | return all_emb_state, all_recurrent_state, all_mask, all_action
184 |
185 | def get_decoded_traj_mujoco(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False):
186 | '''
187 | Decode a trajectory using the policy decoder conditioned on a given policy embedding,
188 | in a given environment for a given initial state. Works with MuJoCo environments (Ant and Swimmer).
189 | '''
190 | self.env.reset(env_id=env_idx)
191 | self.env.sim.set_state(init_state)
192 | state = init_obs
193 |
194 | device = args.device
195 | done = False
196 | episode_reward = 0
197 | t = 0
198 |
199 | all_emb_state = []
200 | all_recurrent_state = []
201 | all_mask = []
202 | all_action = []
203 | for t in range(self.max_episode_steps):
204 | t += 1
205 | recurrent_hidden_state = torch.zeros(policy_emb.shape[0],
206 | decoder.recurrent_hidden_state_size, device=device, requires_grad=True).float()
207 | mask_dec = torch.zeros(policy_emb.shape[0], 1, device=device, requires_grad=True).float()
208 | emb_state_input = torch.cat((policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device)
209 | action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec,
210 | deterministic=True)[1]
211 |
212 | action_flat = action.squeeze().cpu().detach().numpy()
213 | action = action.cpu().detach().numpy()
214 |
215 | all_emb_state.append(emb_state_input)
216 | all_recurrent_state.append(recurrent_hidden_state)
217 | all_mask.append(mask_dec)
218 | all_action.append(torch.tensor(action).to(device))
219 |
220 | next_state, reward, done, _ = self.env.step(action_flat)
221 | episode_reward = args.gamma * episode_reward + reward
222 | state = torch.FloatTensor(next_state)
223 | if done:
224 | break
225 |
226 | if args.norm_reward:
227 | episode_reward = (episode_reward - args.min_reward) / (args.max_reward - args.min_reward)
228 |
229 | return all_emb_state, all_recurrent_state, all_mask, all_action
230 |
231 | def sample_policy_data(self, policy_idx=0, env_idx=None):
232 | '''
233 | Sample data using a given policy.
234 | '''
235 | done = False
236 | state_batch = []
237 | tgt_batch = []
238 | src_batch = []
239 | mask_batch = []
240 | mask_batch_all = []
241 | src = []
242 | masks = []
243 |
244 | if env_idx is not None:
245 | init_state = self.env.reset(env_id=env_idx)
246 | else:
247 | init_state = self.env.reset()
248 |
249 | trajectory = self.sample_policy(policy_idx=policy_idx)
250 | state_tensor = torch.tensor(init_state).unsqueeze(0)
251 | recurrent_hidden_state = torch.zeros(
252 | 1, trajectory.recurrent_hidden_state_size, device=self.args.device)
253 | mask = torch.zeros(1, 1, device=self.args.device)
254 | trajectory = trajectory.float().to(self.args.device)
255 | action = trajectory.act(
256 | state_tensor.to(torch.device(self.args.device)).float(),
257 | recurrent_hidden_state.float(), mask.float(),
258 | deterministic=True)[1]
259 | action_tensor = action.float().reshape(self.action_dim)
260 | state_action_tensor = torch.cat([
261 | state_tensor.to(torch.device(self.args.device)).float().squeeze(),
262 | action_tensor], dim=0).unsqueeze(1)
263 |
264 | for t in range(self.args.max_num_steps):
265 | state_batch.append(state_tensor)
266 | src.append(state_action_tensor)
267 | masks.append(torch.FloatTensor([done == False]))
268 | mask_batch.append(torch.FloatTensor([done == False]))
269 |
270 | state_tensor, _, done, _ = self.env.step(action_tensor.cpu().detach().numpy())
271 |
272 | tgt_batch.append(action_tensor.detach())
273 |
274 |
275 | state_tensor = torch.tensor(state_tensor).unsqueeze(0)
276 | recurrent_hidden_state = torch.zeros(
277 | 1, trajectory.recurrent_hidden_state_size, device=self.args.device)
278 | mask = torch.zeros(1, 1, device=self.args.device)
279 | trajectory = trajectory.float().to(self.args.device)
280 | action = trajectory.act(
281 | state_tensor.to(torch.device(self.args.device)).float(),
282 | recurrent_hidden_state.float(), mask.float(),
283 | deterministic=True)[1].detach()
284 | action_tensor = action.float().reshape(self.action_dim)
285 | state_action_tensor = torch.cat([
286 | state_tensor.to(torch.device(self.args.device)).float().squeeze(),
287 | action_tensor], dim=0).unsqueeze(1)
288 |
289 | for t in range(self.args.max_num_steps):
290 | src_tensor = torch.stack(src).squeeze(2)
291 | src_batch.append(src_tensor)
292 | mask_batch_all.append(torch.stack(masks))
293 |
294 | return state_batch, tgt_batch, src_batch, mask_batch, mask_batch_all
295 |
296 | def sample_k_traj_zeroshot(self, k_traj, policy_idx=0, env_idx=None):
297 | '''
298 | Sample a number of trajectories using the inferred dynamics
299 | from a small number of interactions with a new environment.
300 | '''
301 | trajectory = self.sample_policy(policy_idx=policy_idx)
302 | context_env = self.sample_env_context(trajectory, env_idx=env_idx)
303 | state_action_list = []
304 | target_list = []
305 | source_list = []
306 |
307 | for k in range(k_traj):
308 | eval_episode_rewards = []
309 | init_state = torch.tensor(self.env.reset(same=True)).float()
310 | state = init_state
311 | obs = init_state
312 | trajectory = self.sample_policy(policy_idx=policy_idx)
313 | for _ in range(self.inf_num_steps):
314 | obs_feat = obs.float().squeeze().reshape(1,-1).to(torch.device(self.args.device))
315 |
316 | recurrent_hidden_state = torch.zeros(
317 | 1, trajectory.recurrent_hidden_state_size, device=self.args.device)
318 | mask = torch.zeros(1, 1, device=self.args.device)
319 | trajectory = trajectory.float()
320 | action = trajectory.act(
321 | obs_feat, recurrent_hidden_state.float(), mask.float(),
322 | deterministic=True)[1].detach()
323 |
324 | action_tensor = action.float().reshape(self.action_dim)
325 | state_action_tensor = torch.cat([obs.squeeze().to(device=self.args.device),
326 | action_tensor.to(device=self.args.device)], dim=0)
327 | obs, reward, done, infos = self.env.step(*action.cpu().numpy())
328 |
329 | obs = torch.tensor(obs).float()
330 | target_tensor = obs
331 |
332 | state_action_list.append(state_action_tensor)
333 | target_list.append(target_tensor)
334 | source_list.append(context_env.reshape(self.args.max_num_steps, -1))
335 |
336 | return state_action_list, target_list, source_list
337 |
338 |
339 | class EnvSamplerPDVF():
340 | '''
341 | Environment sampler object for training the
342 | Policy-Dynamics Value Function.
343 | '''
344 | def __init__(self, env, base_policy, args):
345 | if args.env_name.startswith('spaceship'):
346 | self.env = env.venv.envs[0]
347 | elif args.env_name.startswith('myswimmer'):
348 | self.env = env.venv.envs[0].env.env
349 | else:
350 | self.env = env.venv.envs[0].env.env.env
351 | self.env.reset()
352 | self.env1 = env
353 | self.base_policy = base_policy
354 | self.args = args
355 | self.env_name = args.env_name
356 | self.action_dim = self.env1.action_space.shape[0]
357 | self.state_dim = self.env1.observation_space.shape[0]
358 | self.enc_input_size = 2 * self.state_dim + self.action_dim
359 | self.inf_num_steps = args.inf_num_steps
360 | self.env_max_seq_length = args.inf_num_steps * self.enc_input_size
361 | self.max_seq_length = args.max_num_steps * self.enc_input_size
362 |
363 | if 'ant' in args.env_name:
364 | self.max_episode_steps = MAX_EPISODE_STEPS_ANT
365 | elif 'swimmer' in args.env_name:
366 | self.max_episode_steps = MAX_EPISODE_STEPS_SWIMMER
367 | elif 'spaceship' in args.env_name:
368 | self.max_episode_steps = MAX_EPISODE_STEPS_SPACESHIP
369 |
370 | def reset(self):
371 | '''
372 | Reset the environment to an environment that might have different dynamics.
373 | '''
374 | return self.env1.reset()
375 |
376 | def reset_same(self):
377 | '''
378 | Reset the environment to an environment with identical dynamics.
379 | '''
380 | return self.env.reset(same=True)
381 |
382 | def sample_policy(self, policy_idx=0):
383 | '''
384 | Sample a policy from your set of pretrained policies.
385 | '''
386 | return self.base_policy[policy_idx]
387 |
388 | def generic_step(self, policy_fn, state):
389 | '''
390 | Take a step in the environment
391 | '''
392 | recurrent_hidden_state = torch.zeros(
393 | 1, policy_fn.recurrent_hidden_state_size, device=self.args.device)
394 | mask = torch.zeros(1, 1, device=self.args.device)
395 | policy_fn = policy_fn.float().to(self.args.device)
396 | action = policy_fn.act(
397 | state.squeeze().unsqueeze(0).to(torch.device(self.args.device)).float(),
398 | recurrent_hidden_state.float(), mask.float(),
399 | deterministic=True)[1].detach()
400 | action_tensor = action.float().reshape(self.action_dim)
401 | state_action_tensor = torch.cat([state.float().squeeze().to(self.args.device),
402 | action_tensor.to(self.args.device)], dim=0)
403 | next_state, reward, done, _ = self.env.step(*action.cpu().numpy())
404 |
405 | next_state_tensor = torch.Tensor(next_state).float().to(self.args.device)
406 | if not done:
407 | sas_tensor = torch.cat((state_action_tensor, next_state_tensor), dim=0).unsqueeze(1)
408 | else:
409 | sas_tensor = torch.FloatTensor([0 for _ in range(self.enc_input_size)]) \
410 | .reshape(self.enc_input_size, 1)
411 |
412 | res = {'next_state': next_state_tensor, 'action': action_tensor, 'sa_tensor': state_action_tensor,
413 | 'sas_tensor': sas_tensor.to(self.args.device), 'reward': reward, 'done': done}
414 | return res
415 |
416 | def sample_env_context(self, policy_fn, env_idx=None):
417 | '''
418 | Generate a few steps in the new environment.
419 | Use these transitions to infer the environment's dynamics.
420 | '''
421 | src = []
422 | this_src_batch = []
423 | if env_idx is not None:
424 | state = self.env.reset(env_id=env_idx)
425 | else:
426 | state = self.reset_same()
427 | state = torch.tensor(state).float()
428 | done = False
429 |
430 | max_num_steps = self.inf_num_steps
431 |
432 | for t in range(max_num_steps):
433 | res = self.generic_step(policy_fn, state)
434 | state, done, sas_tensor = res['next_state'], res['done'], res['sas_tensor']
435 | src.append(sas_tensor)
436 | src_pad = F.pad(torch.stack(src).reshape((t + 1) * self.enc_input_size), \
437 | (0, self.env_max_seq_length - (t + 1) * len(sas_tensor)))
438 |
439 | this_src_batch.append(src_pad)
440 | if done:
441 | break
442 |
443 | res = this_src_batch[-1].reshape(max_num_steps, -1)
444 | return res
445 |
446 | def zeroshot_sample_src_from_pol_state(self, args, init_state, sizes, policy_idx=0, env_idx=0):
447 | '''
448 | Sample transitions using a certain policy and starting in a given state.
449 | Works for Spaceship.
450 | '''
451 | # get the policy embedding
452 | src_policy = []
453 | masks_policy = []
454 |
455 | episode_reward = 0
456 | state = init_state
457 | policy_fn = self.sample_policy(policy_idx=policy_idx)
458 | for t in range(args.max_num_steps):
459 | res = self.generic_step(policy_fn, state)
460 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \
461 | res['sa_tensor'], res['sas_tensor']
462 | episode_reward = args.gamma * episode_reward + reward
463 | src_policy.append(sa_tensor)
464 | masks_policy.append(torch.FloatTensor([done == False]))
465 | if done:
466 | break
467 |
468 | policy_feats = torch.stack(src_policy).unsqueeze(0)
469 | mask_policy = torch.stack(masks_policy).squeeze(1).unsqueeze(0).unsqueeze(0)
470 |
471 | if self.env_name.startswith('myacrobot'):
472 | episode_reward += args.max_num_steps
473 |
474 | # get the env embedding
475 | src_env = []
476 | env_feats = []
477 |
478 | state = init_state
479 | policy_fn = self.sample_policy(policy_idx=policy_idx)
480 | for t in range(self.inf_num_steps):
481 | res = self.generic_step(policy_fn, state)
482 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \
483 | res['sa_tensor'], res['sas_tensor']
484 | env_feats.append(sas_tensor)
485 | env_pad = F.pad(torch.stack(env_feats).reshape((t+1) * self.enc_input_size), \
486 | (0, self.env_max_seq_length - (t+1) * len(sas_tensor)))
487 | if done:
488 | break
489 |
490 | source_env = torch.stack([env_pad.reshape(self.inf_num_steps,len(env_feats[0]))])
491 | mask_env = (source_env != 0).unsqueeze(-2)
492 | mask_env = mask_env[:, :, :, 0].squeeze(2).unsqueeze(1)
493 |
494 | res = {'source_env': source_env,
495 | 'source_policy': policy_feats,
496 | 'mask_policy': mask_policy,
497 | 'mask_env': mask_env,
498 | 'episode_reward': episode_reward,
499 | 't': t,
500 | 'init_state': init_state}
501 |
502 | return res
503 |
504 | def zeroshot_sample_src_from_pol_state_mujoco(self, args, init_state, sizes, policy_idx=0, env_idx=0):
505 | '''
506 | Sample transitions using a certain policy and starting in a given state.
507 | Works for Swimmer and Ant.
508 | '''
509 | # get the policy embedding
510 | src_policy = []
511 | masks_policy = []
512 |
513 | episode_reward = 0
514 | state = init_state
515 | policy_fn = self.sample_policy(policy_idx=policy_idx)
516 | for t in range(args.max_num_steps):
517 | res = self.generic_step(policy_fn, state)
518 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \
519 | res['sa_tensor'], res['sas_tensor']
520 | episode_reward = args.gamma * episode_reward + reward
521 | src_policy.append(sa_tensor)
522 | masks_policy.append(torch.FloatTensor([done == False]))
523 | if done:
524 | break
525 | if not done:
526 | for t in range(self.max_episode_steps - args.max_num_steps):
527 | res = self.generic_step(policy_fn, state)
528 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \
529 | res['sa_tensor'], res['sas_tensor']
530 | episode_reward = args.gamma * episode_reward + reward
531 | if done:
532 | break
533 |
534 | policy_feats = torch.stack(src_policy).unsqueeze(0)
535 | mask_policy = torch.stack(masks_policy).squeeze(1).unsqueeze(0).unsqueeze(0)
536 |
537 | # get the env embedding
538 | src_env = []
539 | env_feats = []
540 |
541 | state = init_state
542 | policy_fn = self.sample_policy(policy_idx=policy_idx)
543 | for t in range(self.inf_num_steps):
544 | res = self.generic_step(policy_fn, state)
545 | state, done, reward, sa_tensor, sas_tensor = res['next_state'], res['done'], res['reward'], \
546 | res['sa_tensor'], res['sas_tensor']
547 | env_feats.append(sas_tensor)
548 | env_pad = F.pad(torch.stack(env_feats).reshape((t+1) * self.enc_input_size), \
549 | (0, self.env_max_seq_length - (t+1) * len(sas_tensor)))
550 | if done:
551 | break
552 |
553 | source_env = torch.stack([env_pad.reshape(self.inf_num_steps, len(env_feats[0]))])
554 | mask_env = (source_env != 0).unsqueeze(-2)
555 | mask_env = mask_env[:, :, :, 0].squeeze(2).unsqueeze(1)
556 |
557 | if args.norm_reward:
558 | episode_reward = (episode_reward - args.min_reward) / (args.max_reward - args.min_reward)
559 |
560 | res = {'source_env': source_env,
561 | 'source_policy': policy_feats,
562 | 'mask_policy': mask_policy,
563 | 'mask_env': mask_env,
564 | 'episode_reward': episode_reward,
565 | 't': t,
566 | 'init_state': init_state}
567 |
568 | return res
569 |
570 | def get_reward_pol_embedding_state(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False):
571 | '''
572 | Estimate the return using Monte-Carlo for a given policy embedding starting at a given initial state.
573 | Works for Spaceship.
574 | '''
575 | self.env.reset(env_id=env_idx)
576 | self.env.state = init_state
577 | state = init_obs
578 |
579 | device = args.device
580 | done = False
581 | episode_reward = 0
582 | for t in range(args.max_num_steps):
583 | recurrent_hidden_state = torch.zeros(policy_emb.shape[0],
584 | decoder.recurrent_hidden_state_size, device=device, requires_grad=True).float()
585 | mask_dec = torch.zeros(policy_emb.shape[0], 1, device=device, requires_grad=True).float()
586 | emb_state_input = torch.cat((policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device)
587 | action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec,
588 | deterministic=True)[1].squeeze().cpu().detach().numpy()
589 | next_state, reward, done, _ = self.env.step(action)
590 |
591 | episode_reward = args.gamma * episode_reward + reward
592 | state = torch.FloatTensor(next_state)
593 | if done:
594 | break
595 | return episode_reward, t
596 |
597 | def get_reward_pol_embedding_state_mujoco(self, args, init_state, init_obs, policy_emb, decoder, env_idx=0, verbose=False):
598 | '''
599 | Estimate the return using Monte-Carlo for a given policy embedding starting at a given initial state.
600 | Works for Swimmer and Ant.
601 | '''
602 | self.env.reset(env_id=env_idx)
603 | self.env.sim.set_state(init_state)
604 | state = init_obs
605 |
606 | device = args.device
607 | done = False
608 | episode_reward = 0
609 | t = 0
610 | for t in range(self.max_episode_steps):
611 | t += 1
612 | recurrent_hidden_state = torch.zeros(policy_emb.shape[0],
613 | decoder.recurrent_hidden_state_size, device=device, requires_grad=True).float()
614 | mask_dec = torch.zeros(policy_emb.shape[0], 1, device=device, requires_grad=True).float()
615 | emb_state_input = torch.cat((policy_emb.to(device), state.unsqueeze(0).to(device)), dim=1).to(device)
616 | action = decoder.act(emb_state_input, recurrent_hidden_state, mask_dec,
617 | deterministic=True)[1].squeeze().cpu().detach().numpy()
618 | next_state, reward, done, _ = self.env.step(action)
619 |
620 | episode_reward = args.gamma * episode_reward + reward
621 | state = torch.FloatTensor(next_state)
622 | if done:
623 | break
624 |
625 | if args.norm_reward:
626 | episode_reward = (episode_reward - args.min_reward) / (args.max_reward - args.min_reward)
627 |
628 | return episode_reward, t
629 |
630 | def sample_policy_data(self, policy_idx=0, env_idx=None):
631 | '''
632 | Sample transitions from a given policy in your collection.
633 | '''
634 | state_batch = []
635 | tgt_batch = []
636 | src_batch = []
637 | mask_batch = []
638 | mask_batch_all = []
639 |
640 | if env_idx is not None:
641 | state = self.env.reset(env_id=env_idx)
642 | else:
643 | state = self.env.reset()
644 |
645 | trajectory = self.sample_policy(policy_idx=policy_idx)
646 | state = torch.tensor(state).unsqueeze(0)
647 | res = self.generic_step(trajectory, state)
648 | state, action, done, reward, sa_tensor = res['next_state'], res['action'], res['done'], res['reward'], \
649 | res['sa_tensor']
650 | sa_tensor = torch.transpose(sa_tensor, 1, 0)
651 | src = []
652 | masks = []
653 |
654 | for t in range(self.args.max_num_steps):
655 | state_batch.append(state)
656 | src.append(sa_tensor)
657 | tgt_batch.append(action.detach())
658 | res = self.generic_step(trajectory, state)
659 | state, action, done, reward, sa_tensor = res['next_state'], res['action'], res['done'], res['reward'], \
660 | res['sa_tensor']
661 |
662 | sa_tensor = torch.transpose(sa_tensor, 1, 0)
663 | masks.append(torch.FloatTensor([done == False]))
664 | mask_batch.append(torch.FloatTensor([done == False]))
665 |
666 | for t in range(self.args.max_num_steps):
667 | src_tensor = torch.stack(src).squeeze(2)
668 | src_batch.append(src_tensor)
669 | mask_batch_all.append(torch.stack(masks))
670 |
671 | return state_batch, tgt_batch, src_batch, mask_batch, mask_batch_all
672 |
--------------------------------------------------------------------------------
/eval_pdvf.py:
--------------------------------------------------------------------------------
1 | import os, random, sys
2 | import numpy as np
3 |
4 | import torch
5 | import torch.optim as optim
6 | import torch.nn.functional as F
7 |
8 | from pdvf_networks import PDVF
9 |
10 | from pdvf_arguments import get_args
11 |
12 | from ppo.model import Policy
13 | from ppo.envs import make_vec_envs
14 |
15 | import env_utils
16 | import pdvf_utils
17 | import train_utils
18 |
19 | import myant
20 | import myswimmer
21 | import myspaceship
22 |
23 |
24 | def eval_pdvf():
25 | '''
26 | Evaluate the Policy-Dynamics Value Function.
27 | '''
28 | args = get_args()
29 |
30 | torch.manual_seed(args.seed)
31 | torch.cuda.manual_seed_all(args.seed)
32 |
33 | torch.set_num_threads(1)
34 | device = args.device
35 | if device != 'cpu':
36 | torch.cuda.empty_cache()
37 |
38 | if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
39 | torch.backends.cudnn.benchmark = False
40 | torch.backends.cudnn.deterministic = True
41 |
42 | env = make_vec_envs(args, device)
43 | env.reset()
44 |
45 | names = []
46 | for e in range(args.num_envs):
47 | for s in range(args.num_seeds):
48 | names.append('ppo.{}.env{}.seed{}.pt'.format(args.env_name, e, s))
49 |
50 | source_policy = []
51 | for name in names:
52 | actor_critic = Policy(
53 | env.observation_space.shape,
54 | env.action_space,
55 | base_kwargs={'recurrent': False})
56 | actor_critic.to(device)
57 | model = os.path.join(args.save_dir, name)
58 | actor_critic.load_state_dict(torch.load(model))
59 | source_policy.append(actor_critic)
60 |
61 | # Load the pretrained policy and dynamics embedding models (encoders and decoders)
62 | policy_encoder, policy_decoder = pdvf_utils.load_policy_model(
63 | args, env)
64 | env_encoder = pdvf_utils.load_dynamics_model(
65 | args, env)
66 |
67 | value_net = PDVF(env.observation_space.shape[0], args.dynamics_embedding_dim, args.hidden_dim_pdvf,
68 | args.policy_embedding_dim, device=device).to(device)
69 | value_net.to(device)
70 | path_to_pdvf = os.path.join(args.save_dir_pdvf, \
71 | "pdvf-stage{}.{}.pt".format(args.stage, args.env_name))
72 | value_net.load_state_dict(torch.load(path_to_pdvf)['state_dict'])
73 | value_net.eval()
74 |
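# Added note: the first 3/4 of the environments (and the policies trained on them)
# are used as training environments; the remaining 1/4 are held out for evaluation.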
75 | all_envs = [i for i in range(args.num_envs)]
76 | train_policies = [i for i in range(int(3/4*args.num_envs))]
77 | train_envs = [i for i in range(int(3/4*args.num_envs))]
78 | eval_envs = [i for i in range(int(3/4*args.num_envs), args.num_envs)]
79 |
80 | env_enc_input_size = env.observation_space.shape[0] + env.action_space.shape[0]
81 | sizes = pdvf_utils.DotDict({'state_dim': env.observation_space.shape[0], \
82 | 'action_dim': env.action_space.shape[0], 'env_enc_input_size': \
83 | env_enc_input_size, 'env_max_seq_length': args.max_num_steps * env_enc_input_size})
84 |
85 | env_sampler = env_utils.EnvSamplerPDVF(env, source_policy, args)
86 |
87 | all_mean_rewards = [[] for _ in range(args.num_envs)]
88 | all_mean_unnorm_rewards = [[] for _ in range(args.num_envs)]
89 |
90 | # Eval on Train Envs
91 | train_rewards = {}
92 | unnorm_train_rewards = {}
93 | for ei in range(len(all_envs)):
94 | train_rewards[ei] = []
95 | unnorm_train_rewards[ei] = []
96 | for ei in train_envs:
97 | for i in range(args.num_eval_eps):
98 | args.seed = i
99 | np.random.seed(seed=i)
100 | torch.manual_seed(args.seed)
101 | torch.cuda.manual_seed_all(args.seed)
102 | for pi in train_policies:
103 | init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
104 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
105 | init_state = env_sampler.env.sim.get_state()
106 | res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(args, init_obs, sizes, policy_idx=pi, env_idx=ei)
107 | else:
108 | init_state = env_sampler.env.state
109 | res = env_sampler.zeroshot_sample_src_from_pol_state(args, init_obs, sizes, policy_idx=pi, env_idx=ei)
110 |
111 | source_env = res['source_env']
112 | mask_env = res['mask_env']
113 | source_policy = res['source_policy']
114 | init_episode_reward = res['episode_reward']
115 | mask_policy = res['mask_policy']
116 |
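# Added note: if only a single transition was collected, duplicate it (and its mask)
# so the sequence encoder receives a sequence of length at least two.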
117 | if source_policy.shape[1] == 1:
118 | source_policy = source_policy.repeat(1, 2, 1)
119 | mask_policy = mask_policy.repeat(1, 1, 2)
120 | emb_policy = policy_encoder(source_policy.detach().to(device),
121 | mask_policy.detach().to(device)).detach()
122 | if source_env.shape[1] == 1:
123 | source_env = source_env.repeat(1, 2, 1)
124 | mask_env = mask_env.repeat(1, 1, 2)
125 | emb_env = env_encoder(source_env.detach().to(device),
126 | mask_env.detach().to(device)).detach()
127 |
128 | emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
129 | emb_env = F.normalize(emb_env, p=2, dim=1).detach()
130 |
131 | pred_value = value_net(init_obs.unsqueeze(0).to(device),
132 | emb_env.to(device), emb_policy.to(device)).item()
133 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
134 | decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(args,
135 | init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0]
136 | else:
137 | decoded_reward = env_sampler.get_reward_pol_embedding_state(args,
138 | init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0]
139 |
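# Added note (assumption about intent): get_qf returns the quadratic/bilinear form of
# the value function over policy embeddings for this state and inferred dynamics; its
# top singular vector is taken as the candidate optimal unit-norm policy embedding.
# Both signs are decoded below and the one achieving the higher return is kept.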
140 | qf = value_net.get_qf(init_obs.unsqueeze(0).to(device), emb_env)
141 | u, s, v = torch.svd(qf.squeeze())
142 |
143 | opt_policy_pos = u[:,0].unsqueeze(0)
144 | opt_policy_neg = -u[:,0].unsqueeze(0)
145 |
146 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
147 | episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
148 | args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei)
149 | episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
150 | args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei)
151 | else:
152 | episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
153 | args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei)
154 | episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
155 | args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei)
156 |
157 | if episode_reward_pos >= episode_reward_neg:
158 | episode_reward = episode_reward_pos
159 | opt_policy = opt_policy_pos
160 | else:
161 | episode_reward = episode_reward_neg
162 | opt_policy = opt_policy_neg
163 |
164 | unnorm_episode_reward = episode_reward * (args.max_reward - args.min_reward) + args.min_reward
165 | unnorm_init_episode_reward = init_episode_reward * (args.max_reward - args.min_reward) + args.min_reward
166 | unnorm_decoded_reward = decoded_reward * (args.max_reward - args.min_reward) + args.min_reward
167 |
168 | unnorm_train_rewards[ei].append(unnorm_episode_reward)
169 | train_rewards[ei].append(episode_reward)
170 | if i % args.log_interval == 0:
171 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
172 | print(f"\nTrain Environment: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}")
173 | print(f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}")
174 | print(f"Train Environment: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}")
175 | print(f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}")
176 |
177 | all_mean_rewards[ei].append(np.mean(train_rewards[ei]))
178 | all_mean_unnorm_rewards[ei].append(np.mean(unnorm_train_rewards[ei]))
179 |
180 | for ei in train_envs:
181 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
182 | print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
183 | .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei])))
184 | else:
185 | print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
186 | .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei])))
187 |
188 |
189 | # Eval on Eval Envs
190 | eval_rewards = {}
191 | unnorm_eval_rewards = {}
192 | for ei in range(len(all_envs)):
193 | eval_rewards[ei] = []
194 | unnorm_eval_rewards[ei] = []
195 | for ei in eval_envs:
196 | for i in range(args.num_eval_eps):
197 | args.seed = i
198 | np.random.seed(seed=i)
199 | torch.manual_seed(args.seed)
200 | torch.cuda.manual_seed_all(args.seed)
201 |
202 | for pi in train_policies:
203 | init_obs = torch.FloatTensor(env_sampler.env.reset(env_id=ei))
204 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
205 | init_state = env_sampler.env.sim.get_state()
206 | res = env_sampler.zeroshot_sample_src_from_pol_state_mujoco(args, init_obs, sizes, policy_idx=pi, env_idx=ei)
207 | else:
208 | init_state = env_sampler.env.state
209 | res = env_sampler.zeroshot_sample_src_from_pol_state(args, init_obs, sizes, policy_idx=pi, env_idx=ei)
210 |
211 | source_env = res['source_env']
212 | mask_env = res['mask_env']
213 | source_policy = res['source_policy']
214 | init_episode_reward = res['episode_reward']
215 | mask_policy = res['mask_policy']
216 |
217 | if source_policy.shape[1] == 1:
218 | source_policy = source_policy.repeat(1, 2, 1)
219 | mask_policy = mask_policy.repeat(1, 1, 2)
220 | emb_policy = policy_encoder(source_policy.detach().to(device),
221 | mask_policy.detach().to(device)).detach()
222 | if source_env.shape[1] == 1:
223 | source_env = source_env.repeat(1, 2, 1)
224 | mask_env = mask_env.repeat(1, 1, 2)
225 | emb_env = env_encoder(source_env.detach().to(device),
226 | mask_env.detach().to(device)).detach()
227 |
228 | emb_policy = F.normalize(emb_policy, p=2, dim=1).detach()
229 | emb_env = F.normalize(emb_env, p=2, dim=1).detach()
230 |
231 | pred_value = value_net(init_obs.unsqueeze(0).to(device),
232 | emb_env.to(device), emb_policy.to(device)).item()
233 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
234 | decoded_reward = env_sampler.get_reward_pol_embedding_state_mujoco(args,
235 | init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0]
236 | else:
237 | decoded_reward = env_sampler.get_reward_pol_embedding_state(args,
238 | init_state, init_obs, emb_policy, policy_decoder, env_idx=ei)[0]
239 |
240 | qf = value_net.get_qf(init_obs.unsqueeze(0).to(device), emb_env)
241 | u, s, v = torch.svd(qf.squeeze())
242 |
243 | opt_policy_pos = u[:,0].unsqueeze(0)
244 | opt_policy_neg = -u[:,0].unsqueeze(0)
245 |
246 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
247 | episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state_mujoco(
248 | args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei)
249 | episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state_mujoco(
250 | args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei)
251 | else:
252 | episode_reward_pos, num_steps_pos = env_sampler.get_reward_pol_embedding_state(
253 | args, init_state, init_obs, opt_policy_pos, policy_decoder, env_idx=ei)
254 | episode_reward_neg, num_steps_neg = env_sampler.get_reward_pol_embedding_state(
255 | args, init_state, init_obs, opt_policy_neg, policy_decoder, env_idx=ei)
256 |
257 | if episode_reward_pos >= episode_reward_neg:
258 | episode_reward = episode_reward_pos
259 | opt_policy = opt_policy_pos
260 | else:
261 | episode_reward = episode_reward_neg
262 | opt_policy = opt_policy_neg
263 |
264 | unnorm_episode_reward = episode_reward * (args.max_reward - args.min_reward) + args.min_reward
265 | unnorm_init_episode_reward = init_episode_reward * (args.max_reward - args.min_reward) + args.min_reward
266 | unnorm_decoded_reward = decoded_reward * (args.max_reward - args.min_reward) + args.min_reward
267 |
268 | unnorm_eval_rewards[ei].append(unnorm_episode_reward)
269 | eval_rewards[ei].append(episode_reward)
270 | if i % args.log_interval == 0:
271 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
272 | print(f"\nEval Environment: {ei} -- top singular value: {s[0].item(): .3f} --- reward after update: {unnorm_episode_reward: .3f}")
273 | print(f"Initial Policy: {pi} --- init true reward: {unnorm_init_episode_reward: .3f} --- decoded: {unnorm_decoded_reward: .3f} --- predicted: {pred_value: .3f}")
274 | print(f"Eval Environment: {ei} -- top singular value: {s[0].item(): .3f} --- norm reward after update: {episode_reward: .3f}")
275 | print(f"Initial Policy: {pi} --- norm init true reward: {init_episode_reward: .3f} --- norm decoded: {decoded_reward: .3f} --- predicted: {pred_value: .3f}")
276 |
277 | all_mean_rewards[ei].append(np.mean(eval_rewards[ei]))
278 | all_mean_unnorm_rewards[ei].append(np.mean(unnorm_eval_rewards[ei]))
279 |
280 | for ei in train_envs:
281 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
282 | print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
283 | .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei])))
284 | else:
285 | print("Train Env {} has reward with mean {:.3f} and std {:.3f}"\
286 | .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei])))
287 |
288 | for ei in eval_envs:
289 | if 'ant' in args.env_name or 'swimmer' in args.env_name:
290 | print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\
291 | .format(ei, np.mean(all_mean_unnorm_rewards[ei]), np.std(all_mean_unnorm_rewards[ei])))
292 | else:
293 | print("Eval Env {} has reward with mean {:.3f} and std {:.3f}"\
294 | .format(ei, np.mean(all_mean_rewards[ei]), np.std(all_mean_rewards[ei])))
295 |
296 | env.close()
297 |
298 | if __name__ == '__main__':
299 | eval_pdvf()
--------------------------------------------------------------------------------
/figures/ablations_pdvf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rraileanu/policy-dynamics-value-functions/90c1aa26228103edf8b01d6bcddb70dc37a73e25/figures/ablations_pdvf.png
--------------------------------------------------------------------------------
/figures/pdvf_gif.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rraileanu/policy-dynamics-value-functions/90c1aa26228103edf8b01d6bcddb70dc37a73e25/figures/pdvf_gif.png
--------------------------------------------------------------------------------
/figures/results_pdvf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rraileanu/policy-dynamics-value-functions/90c1aa26228103edf8b01d6bcddb70dc37a73e25/figures/results_pdvf.png
--------------------------------------------------------------------------------
/myant/myant/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id='myant-v0',
5 | entry_point='myant.envs:AntEnv',
6 | max_episode_steps=256,
7 | reward_threshold=-3.75,
8 | )
9 |
--------------------------------------------------------------------------------
/myant/myant/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from myant.envs.myant import AntEnv
2 |
--------------------------------------------------------------------------------
/myant/myant/envs/myant.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym.envs.mujoco import mujoco_env
3 | from gym.envs.mujoco.ant_v3 import AntEnv as AntEnvOrig
4 |
5 | import re
6 | import os
7 | import random
8 | import numpy as np
9 |
10 | from tempfile import mkdtemp
11 | import contextlib
12 | from shutil import copyfile, rmtree
13 | from pathlib import Path
14 |
15 |
16 | @contextlib.contextmanager
17 | def make_temp_directory(prefix=''):
18 | temp_dir = mkdtemp(prefix=prefix)
19 | try:
20 | yield temp_dir
21 | finally:
22 | rmtree(temp_dir)
23 |
24 |
25 | class AntEnv(AntEnvOrig):
26 | '''
27 | Family of Ant environments with different but related dynamics.
28 | '''
29 | def __init__(self, default_ind=0, num_envs=20, radius=4.0, viscosity=0.05, basepath=None):
30 | self.num_envs = num_envs
31 | AntEnvOrig.__init__(self)
32 |
33 | self.default_params = {'limbs': [.2, .2, .2, .2], 'wind': [0, 0, 0], 'viscosity': 0.0}
34 | self.default_ind = default_ind
35 |
36 | self.env_configs = []
37 |
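# Added note: each environment applies a wind vector of fixed magnitude (radius) whose
# direction is evenly spaced around the circle, so the num_envs environments share the
# same morphology and differ only in wind direction (plus viscosity vs. the default).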
38 | for i in range(num_envs):
39 | angle = i * (1/num_envs) * (2*np.pi)
40 | wind_x = radius * np.cos(angle)
41 | wind_y = radius * np.sin(angle)
42 | self.env_configs.append(
43 | {'limbs': [.2, .2, .2, .2], 'wind': [wind_x, wind_y, 0], 'viscosity': viscosity}
44 | )
45 |
46 | self.env_configs.append(self.default_params)
47 |
48 | self.angle = (self.default_ind + 1) * (1/num_envs) * (2*np.pi)
49 | self.wind_x = radius * np.cos(self.angle)
50 | self.wind_y = radius * np.sin(self.angle)
51 |
52 | self.basepath = basepath
53 | file = open(os.path.join(self.basepath, "ant.xml"))
54 | self.xml = file.readlines()
55 | file.close()
56 |
57 |
58 | def get_xml(self, ind=0):
59 | xmlfile = open(os.path.join(self.basepath, "ant.xml"))
60 | tmp = xmlfile.read()
61 | xmlfile.close()
62 | if ind is None:
63 | ind = self.default_ind
64 |
65 | params = {}
66 | params.update(self.default_params)
67 | params.update(self.env_configs[ind])
68 | limb_params = params['limbs']
69 | wind_params = params['wind']
70 | viscosity = params['viscosity']
71 | wx, wy, wz = wind_params
72 |
73 | tmp = re.sub("