├── expert
│   ├── assets
│   │   └── expert_traj
│   │       ├── Hopper-v2_expert_traj.p
│   │       └── Swimmer-v2_expert_traj.p
│   ├── sac_args.py
│   ├── save_traj_ppo.py
│   ├── save_traj_sac.py
│   ├── train_sac.py
│   └── train_ppo.py
├── utils
│   ├── __init__.py
│   ├── tools.py
│   ├── math.py
│   ├── zfilter.py
│   ├── replay_memory.py
│   ├── render.py
│   ├── torch.py
│   └── utils.py
├── envs
│   ├── test_envs.py
│   ├── mujoco
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── light_swimmer.xml
│   │   │   ├── swimmer.xml
│   │   │   ├── heavy_swimmer.xml
│   │   │   ├── disabled_swimmer.xml
│   │   │   ├── ant.xml
│   │   │   ├── light_ant.xml
│   │   │   ├── heavy_ant.xml
│   │   │   └── disabled_ant.xml
│   │   ├── swimmer.py
│   │   └── ant.py
│   ├── __init__.py
│   ├── cycle_4room.py
│   ├── mountaincar.py
│   └── fourroom.py
├── core
│   ├── common.py
│   ├── ppo.py
│   └── agent.py
├── LICENSE
├── models
│   ├── rnd_model.py
│   ├── WGAN.py
│   ├── VAE.py
│   ├── sac_models.py
│   ├── ppo_models.py
│   └── dynamics.py
├── .gitignore
├── README.md
├── sail.py
└── agents
    ├── sac_agent.py
    └── soft_bc_agent.py
/expert/assets/expert_traj/Hopper-v2_expert_traj.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FangchenLiu/SAIL/HEAD/expert/assets/expert_traj/Hopper-v2_expert_traj.p
--------------------------------------------------------------------------------
/expert/assets/expert_traj/Swimmer-v2_expert_traj.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FangchenLiu/SAIL/HEAD/expert/assets/expert_traj/Swimmer-v2_expert_traj.p
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from utils.replay_memory import *
2 | from utils.zfilter import *
3 | from utils.torch import *
4 | from utils.math import *
5 | from utils.tools import *
6 |
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | import torch
3 |
4 | def assets_dir():
5 | return path.abspath(path.join(path.dirname(path.abspath(__file__)), '../expert/assets'))
6 |
7 | def swish(x):
8 | return x * torch.sigmoid(x)
9 |
--------------------------------------------------------------------------------
/utils/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | def normal_entropy(std):
6 | var = std.pow(2)
7 | entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi)
8 | return entropy.sum(1, keepdim=True)
9 |
10 |
11 | def normal_log_density(x, mean, log_std, std):
12 | var = std.pow(2)
13 | log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) - log_std
14 | return log_density.sum(1, keepdim=True)
15 |
--------------------------------------------------------------------------------
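
These two helpers compute the entropy and the per-sample log density of a diagonal Gaussian, summed over the last dimension. A minimal sanity-check sketch against `torch.distributions.Normal` (the `[batch, dim]` shapes are assumed, matching how the functions are called elsewhere in the repo):

```python
import torch
from torch.distributions import Normal
from utils.math import normal_log_density, normal_entropy

# Batch of 4 samples from 3-dimensional diagonal Gaussians.
x = torch.randn(4, 3)
mean = torch.zeros(4, 3)
log_std = torch.full((4, 3), -0.5)
std = log_std.exp()

# Per-sample log density, summed over dimensions -> shape [4, 1].
ld = normal_log_density(x, mean, log_std, std)
ref = Normal(mean, std).log_prob(x).sum(1, keepdim=True)
assert torch.allclose(ld, ref, atol=1e-5)

# Entropy: 0.5 * log(2*pi*e*var) per dimension, summed -> shape [4, 1].
ent = normal_entropy(std)
ref_ent = Normal(mean, std).entropy().sum(1, keepdim=True)
assert torch.allclose(ent, ref_ent, atol=1e-5)
```
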
/envs/test_envs.py:
--------------------------------------------------------------------------------
1 | from utils.render import play
2 | import gym
3 | import envs
4 | import envs.mujoco
5 |
6 | if __name__ == '__main__':
7 | env = gym.make('DisableAnt-v0')
8 | play(env, None, None, video_path='disable_ant.avi', time_limit=200, device='cpu')
9 | env = gym.make('LightAnt-v0')
10 | play(env, None, None, video_path='light_ant.avi', time_limit=200, device='cpu')
11 | env = gym.make('HeavyAnt-v0')
12 | play(env, None, None, video_path='heavy_ant.avi', time_limit=200, device='cpu')
--------------------------------------------------------------------------------
/core/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import to_device
3 |
4 |
5 | def estimate_advantages(rewards, masks, values, gamma, tau, device):
6 | rewards, masks, values = to_device(torch.device('cpu'), rewards, masks, values)
7 | tensor_type = type(rewards)
8 | deltas = tensor_type(rewards.size(0), 1)
9 | advantages = tensor_type(rewards.size(0), 1)
10 |
11 | prev_value = 0
12 | prev_advantage = 0
13 | for i in reversed(range(rewards.size(0))):
14 | deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values[i]
15 | advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i]
16 |
17 | prev_value = values[i, 0]
18 | prev_advantage = advantages[i, 0]
19 |
20 | returns = values + advantages
21 | advantages = (advantages - advantages.mean()) / advantages.std()
22 |
23 | advantages, returns = to_device(device, advantages, returns)
24 | return advantages, returns
--------------------------------------------------------------------------------
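
`estimate_advantages` computes GAE(λ) over a flat batch of transitions (`tau` plays the role of λ, and `masks[i] = 0` marks the end of an episode); returns are `values + advantages` before the advantages are normalized. A small usage sketch with toy tensors (shapes inferred from how the function indexes its arguments):

```python
import torch
from core.common import estimate_advantages

# Two short episodes concatenated into one batch; mask = 0 at episode ends.
rewards = torch.tensor([[1.0], [1.0], [1.0], [0.5], [0.5]])
masks   = torch.tensor([[1.0], [1.0], [0.0], [1.0], [0.0]])
values  = torch.tensor([[0.9], [0.8], [0.7], [0.4], [0.3]])  # critic estimates V(s_t)

advantages, returns = estimate_advantages(
    rewards, masks, values, gamma=0.99, tau=0.95, device=torch.device('cpu'))

# Advantages come back normalized to zero mean / unit std;
# returns = values + (unnormalized) advantages.
print(advantages.shape, returns.shape)  # torch.Size([5, 1]) torch.Size([5, 1])
```
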
/envs/mujoco/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 |
3 | register(
4 | id='DisableAnt-v0',
5 | entry_point='envs.mujoco.ant:DisableAntEnv',
6 | max_episode_steps=1000,
7 | reward_threshold=6000.0,
8 | )
9 |
10 | register(
11 | id='HeavyAnt-v0',
12 | entry_point='envs.mujoco.ant:HeavyAntEnv',
13 | max_episode_steps=1000,
14 | reward_threshold=6000.0,
15 | )
16 |
17 | register(
18 | id='LightAnt-v0',
19 | entry_point='envs.mujoco.ant:LightAntEnv',
20 | max_episode_steps=1000,
21 | reward_threshold=6000.0,
22 | )
23 |
24 | register(
25 | id='DisableSwimmer-v0',
26 | entry_point='envs.mujoco.swimmer:DisableSwimmerEnv',
27 | max_episode_steps=1000,
28 | reward_threshold=360.0,
29 | )
30 | register(
31 | id='LightSwimmer-v0',
32 | entry_point='envs.mujoco.swimmer:LightSwimmerEnv',
33 | max_episode_steps=1000,
34 | reward_threshold=360.0,
35 | )
36 | register(
37 | id='HeavySwimmer-v0',
38 | entry_point='envs.mujoco.swimmer:HeavySwimmerEnv',
39 | max_episode_steps=1000,
40 | reward_threshold=360.0,
41 | )
--------------------------------------------------------------------------------
/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.registration import register
2 | register(
3 | id='Fourroom-v0',
4 | entry_point='envs.fourroom:FourRoom',
5 | kwargs={'goal_type':'fix_goal'},
6 | max_episode_steps=200,
7 | reward_threshold=100.0,
8 | nondeterministic=False,
9 | )
10 | register(
11 | id='Fourroom-v1',
12 | entry_point='envs.fourroom:FourRoom1',
13 | kwargs={'goal_type':'fix_goal'},
14 | max_episode_steps=200,
15 | reward_threshold=100.0,
16 | nondeterministic=False,
17 | )
18 | register(
19 | id='CycleFourroom-v0',
20 | entry_point='envs.cycle_4room:FourRoom',
21 | max_episode_steps=200,
22 | reward_threshold=100.0,
23 | nondeterministic=False,
24 | )
25 | register(
26 | id='CycleFourroom-v1',
27 | entry_point='envs.cycle_4room:FourRoom1',
28 | max_episode_steps=200,
29 | reward_threshold=100.0,
30 | nondeterministic=False,
31 | )
32 | register(
33 | id='MyMountainCar-v0',
34 | entry_point='envs.mountaincar:Continuous_MountainCarEnv',
35 | max_episode_steps=999,
36 | reward_threshold=90.0,
37 | nondeterministic=False,
38 | )
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 FangchenLiu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/core/ppo.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, optim_value_iternum, states, actions,
5 | returns, advantages, fixed_log_probs, clip_epsilon, l2_reg):
6 |
7 | """update critic"""
8 | for _ in range(optim_value_iternum):
9 | values_pred = value_net(states)
10 | value_loss = (values_pred - returns).pow(2).mean()
11 | # weight decay
12 | for param in value_net.parameters():
13 | value_loss += param.pow(2).sum() * l2_reg
14 | optimizer_value.zero_grad()
15 | value_loss.backward()
16 | optimizer_value.step()
17 |
18 | """update policy"""
19 | log_probs = policy_net.get_log_prob(states, actions)
20 | ratio = torch.exp(log_probs - fixed_log_probs)
21 | surr1 = ratio * advantages
22 | surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
23 | policy_surr = -torch.min(surr1, surr2).mean()
24 | optimizer_policy.zero_grad()
25 | policy_surr.backward()
26 | torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 40)
27 | optimizer_policy.step()
28 |
--------------------------------------------------------------------------------
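
`ppo_step` runs `optim_value_iternum` L2-regularized regression steps on the value network, then one clipped-surrogate policy step on `min(ratio * A, clip(ratio, 1 - ε, 1 + ε) * A)` with gradient-norm clipping. A minimal sketch with stand-in networks; the real `Policy`/`Value` classes live in `models/ppo_models.py` (not shown here), so the stand-in below only mimics the one method `ppo_step` calls on the policy:

```python
import torch
import torch.nn as nn
from torch.distributions import Normal
from core.ppo import ppo_step

class TinyGaussianPolicy(nn.Module):
    """Stand-in exposing get_log_prob(states, actions), as ppo_step expects."""
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.mean = nn.Linear(state_dim, action_dim)
        self.log_std = nn.Parameter(torch.zeros(action_dim))

    def get_log_prob(self, states, actions):
        dist = Normal(self.mean(states), self.log_std.exp())
        return dist.log_prob(actions).sum(1, keepdim=True)

state_dim, action_dim, batch = 11, 3, 64
policy_net = TinyGaussianPolicy(state_dim, action_dim)
value_net = nn.Sequential(nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, 1))
opt_pi = torch.optim.Adam(policy_net.parameters(), lr=3e-4)
opt_v = torch.optim.Adam(value_net.parameters(), lr=3e-4)

states = torch.randn(batch, state_dim)
actions = torch.randn(batch, action_dim)
returns = torch.randn(batch, 1)
advantages = torch.randn(batch, 1)
with torch.no_grad():
    fixed_log_probs = policy_net.get_log_prob(states, actions)  # pi_old log-probs

ppo_step(policy_net, value_net, opt_pi, opt_v, optim_value_iternum=1,
         states=states, actions=actions, returns=returns, advantages=advantages,
         fixed_log_probs=fixed_log_probs, clip_epsilon=0.2, l2_reg=1e-3)
```
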
/models/rnd_model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.nn import init
4 | import numpy as np
5 |
6 | class Coskx(nn.Module):
7 | def __init__(self, k=50):
8 | super(Coskx, self).__init__()
9 | self.k = k
10 | def forward(self, input):
11 | return torch.cos(input * self.k)
12 |
13 | class RND(nn.Module):
14 | def __init__(self, d, last_size=512, ptb=10):
15 | super(RND, self).__init__()
16 | self.target = nn.Sequential(
17 | nn.Linear(d, 512),
18 | nn.LeakyReLU(),
19 | nn.Linear(512, 512),
20 | nn.LeakyReLU(),
21 | Coskx(100),
22 | nn.Linear(512, last_size)
23 | )
24 | self.predictor = nn.Sequential(
25 | nn.Linear(d, 512),
26 | nn.LeakyReLU(),
27 | nn.Linear(512, 512),
28 | nn.LeakyReLU(),
29 | nn.Linear(512, 512),
30 | nn.LeakyReLU(),
31 | nn.Linear(512, last_size)
32 | )
33 | for p in self.modules():
34 | if isinstance(p, nn.Linear):
35 | init.orthogonal_(p.weight, np.sqrt(2))
36 |
37 | for param in self.target.parameters():
38 | param.requires_grad = False
39 |
40 | def forward(self, states):
41 | target_feature = self.target(states).detach()
42 | predict_feature = self.predictor(states)
43 | return ((predict_feature - target_feature) ** 2).mean(-1)
44 |
45 | def get_q(self, states):
46 | err = self.forward(states)
47 | return torch.exp(-10*err).detach()
48 |
--------------------------------------------------------------------------------
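
`RND` is a random-network-distillation novelty model: a frozen, randomly initialized `target` network defines features, the `predictor` is trained to regress onto them, and the squared prediction error serves as the novelty signal; `get_q` turns it into a familiarity score `exp(-10 * err)` in (0, 1]. A minimal usage sketch (the state dimension and iteration counts are illustrative):

```python
import torch
from models.rnd_model import RND

state_dim = 11
rnd = RND(state_dim)
optimizer = torch.optim.Adam(rnd.predictor.parameters(), lr=1e-4)

# Fit the predictor on "seen" states (e.g. expert states).
seen = torch.randn(1024, state_dim)
for _ in range(200):
    loss = rnd(seen).mean()   # mean squared feature-prediction error
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Familiar states score close to 1, novel (out-of-distribution) states lower.
novel = 10.0 * torch.randn(1024, state_dim)
print(rnd.get_q(seen).mean().item(), rnd.get_q(novel).mean().item())
```
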
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### JetBrains template
3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
5 |
6 | # User-specific stuff:
7 | .idea/**/workspace.xml
8 | .idea/**/tasks.xml
9 | .idea/dictionaries
10 | ./gail/load_expert_traj.py
11 | # Sensitive or high-churn files:
12 | .idea/**/dataSources/
13 | .idea/**/dataSources.ids
14 | .idea/**/dataSources.xml
15 | .idea/**/dataSources.local.xml
16 | .idea/**/sqlDataSources.xml
17 | .idea/**/dynamic.xml
18 | .idea/**/uiDesigner.xml
19 |
20 | # Gradle:
21 | .idea/**/gradle.xml
22 | .idea/**/libraries
23 |
24 | # CMake
25 | cmake-build-debug/
26 | cmake-build-release/
27 |
28 | # Mongo Explorer plugin:
29 | .idea/**/mongoSettings.xml
30 |
31 | ## File-based project format:
32 | *.iws
33 |
34 | ## Plugin-specific files:
35 |
36 | # IntelliJ
37 | out/
38 | runs/
39 |
40 | # mpeltonen/sbt-idea plugin
41 | .idea_modules/
42 |
43 | # JIRA plugin
44 | atlassian-ide-plugin.xml
45 |
46 | # Cursive Clojure plugin
47 | .idea/replstate.xml
48 |
49 | # Crashlytics plugin (for Android Studio and IntelliJ)
50 | com_crashlytics_export_strings.xml
51 | crashlytics.properties
52 | crashlytics-build.properties
53 | fabric.properties
54 |
55 | .idea/handful-of-trials-minimal-cartpole.iml
56 | .idea/inspectionProfiles/
57 | .idea/markdown-navigator/
58 | .idea/misc.xml
59 | .idea/modules.xml
60 | .idea/workspace.xml
61 | .idea/
62 | __pycache__/
63 | learned_models/
64 | log/
65 | *.lprof
66 | .DS_Store
67 | .pt
68 | .p
69 | *expert_traj.p
70 |
--------------------------------------------------------------------------------
/utils/zfilter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # from https://github.com/joschu/modular_rl
4 | # http://www.johndcook.com/blog/standard_deviation/
5 |
6 |
7 | class RunningStat(object):
8 | def __init__(self, shape):
9 | self._n = 0
10 | self._M = np.zeros(shape)
11 | self._S = np.zeros(shape)
12 |
13 | def push(self, x):
14 | x = np.asarray(x)
15 | assert x.shape == self._M.shape
16 | self._n += 1
17 | if self._n == 1:
18 | self._M[...] = x
19 | else:
20 | oldM = self._M.copy()
21 | self._M[...] = oldM + (x - oldM) / self._n
22 | self._S[...] = self._S + (x - oldM) * (x - self._M)
23 |
24 | @property
25 | def n(self):
26 | return self._n
27 |
28 | @property
29 | def mean(self):
30 | return self._M
31 |
32 | @property
33 | def var(self):
34 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
35 |
36 | @property
37 | def std(self):
38 | return np.sqrt(self.var)
39 |
40 | @property
41 | def shape(self):
42 | return self._M.shape
43 |
44 |
45 | class ZFilter:
46 | """
47 | y = (x-mean)/std
48 | using running estimates of mean,std
49 | """
50 |
51 | def __init__(self, shape, demean=True, destd=True, clip=10.0):
52 | self.demean = demean
53 | self.destd = destd
54 | self.clip = clip
55 | self.rs = RunningStat(shape)
56 |
57 | def __call__(self, x, update=True):
58 | if update:
59 | self.rs.push(x)
60 | if self.demean:
61 | x = x - self.rs.mean
62 | if self.destd:
63 | x = x / (self.rs.std + 1e-8)
64 | if self.clip:
65 | x = np.clip(x, -self.clip, self.clip)
66 | return x
67 |
--------------------------------------------------------------------------------
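
`RunningStat` is Welford's online mean/variance estimator, and `ZFilter` uses it to z-normalize (and optionally clip) observations; passing `update=False` freezes the statistics, e.g. for evaluation or for replaying expert states. A small usage sketch:

```python
import numpy as np
from utils.zfilter import ZFilter

obs_dim = 11
running_state = ZFilter((obs_dim,), clip=5.0)

# During data collection the statistics are updated with every observation.
for _ in range(1000):
    raw_obs = np.random.randn(obs_dim) * 3.0 + 2.0
    norm_obs = running_state(raw_obs)            # update=True by default

# At evaluation time the statistics are frozen.
eval_obs = running_state(np.random.randn(obs_dim), update=False)
print(running_state.rs.mean.round(2), running_state.rs.std.round(2))
```
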
/utils/replay_memory.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import random
3 | import numpy as np
4 |
5 |
6 | Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state',
7 | 'mask'))
8 |
9 | class Memory:
10 | def __init__(self, capacity):
11 | self.capacity = int(capacity)
12 | self.buffer = []
13 | self.position = 0
14 |
15 | def push(self, state, action, reward, next_state, done):
16 | if len(self.buffer) < self.capacity:
17 | self.buffer.append(None)
18 | self.buffer[self.position] = (state, action, reward, next_state, done)
19 | self.position = (self.position + 1) % self.capacity
20 |
21 | def sample(self, batch_size):
22 | batch = random.sample(self.buffer, batch_size)
23 | #print(batch)
24 | state, action, reward, next_state, done = map(np.stack, zip(*batch))
25 | return state, action, reward, next_state, done
26 |
27 | def sample_all(self):
28 | batch = self.buffer
29 | state, action, reward, next_state, done = map(np.stack, zip(*batch))
30 | return state, action, reward, next_state, done
31 |
32 | def __len__(self):
33 | return len(self.buffer)
34 |
35 | class OnlineMemory(object):
36 | def __init__(self):
37 | self.memory = []
38 |
39 | def push(self, *args):
40 | """Saves a transition."""
41 | self.memory.append(Transition(*args))
42 |
43 | def sample(self, batch_size=None):
44 | if batch_size is None:
45 | return Transition(*zip(*self.memory))
46 | else:
47 | random_batch = random.sample(self.memory, batch_size)
48 | return Transition(*zip(*random_batch))
49 |
50 | def append(self, new_memory):
51 | self.memory += new_memory.memory
52 |
53 | def __len__(self):
54 | return len(self.memory)
55 |
56 | def clear(self):
57 | self.memory = []
58 |
59 |
--------------------------------------------------------------------------------
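
Two buffers are defined here: `Memory` is a fixed-capacity ring buffer used for off-policy (SAC-style) training, and `OnlineMemory` accumulates `Transition` namedtuples for on-policy rollouts. A minimal usage sketch with arbitrary toy transitions:

```python
import numpy as np
from utils.replay_memory import Memory, OnlineMemory

# Off-policy ring buffer: push (state, action, reward, next_state, mask) tuples.
memory = Memory(capacity=1e4)
for _ in range(100):
    s, a = np.random.randn(11), np.random.randn(3)
    memory.push(s, a, 1.0, s + 0.01, 1.0)
state, action, reward, next_state, mask = memory.sample(batch_size=32)
print(state.shape, action.shape)            # (32, 11) (32, 3)

# On-policy buffer: sample() with no argument returns the whole rollout
# as a Transition of tuples.
traj = OnlineMemory()
for _ in range(10):
    traj.push(np.random.randn(11), np.random.randn(3), 0.0, np.random.randn(11), 1.0)
batch = traj.sample()
print(len(batch.state), len(batch.action))  # 10 10
```
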
/README.md:
--------------------------------------------------------------------------------
1 | # State Alignment-based Imitation Learning
2 | We propose a state-based imitation learning method for cross-morphology imitation, which considers both the state visitation distribution and local transition alignment.
3 |
4 | ### ATTENTION: The code is in a pre-release stage and under maintenance. Some details may differ from the paper or be missing, due to internal integration with other projects. The author will reorganize these parts to match the original paper as soon as possible.
5 |
6 | ## Train
7 | The demonstrations are provided in ``./expert/assets/expert_traj``. They were obtained from well-trained expert policies (SAC or TRPO).
8 |
9 | The cross-morphology imitator environments (modified MuJoCo models) can be found in ``./envs/mujoco/assets``. You can also customize your own environments for training.
10 |
11 | The main algorithm contains two settings: original imitation learning and cross-morphology imitation learning.
12 | The only difference is the pretraining stage for the inverse dynamics model.
13 |
14 | The command format is:
15 | ```bash
16 | python sail.py --env-name [YOUR-ENV-NAME] --expert-traj-path [PATH-TO-DEMO] --beta 0.01 --resume [if want resume] --transfer [if cross morphology]
17 | ```
18 | For example, for original hopper imitation:
19 | ```bash
20 | python sail.py --env-name Hopper-v2 --expert-traj-path ./expert/assets/expert_traj/Hopper-v2_expert_traj.p --beta 0.005
21 | ```
22 | and for disabled swimmer imitation:
23 | ```bash
24 | python sail.py --env-name DisableSwimmer-v0 --expert-traj-path ./expert/assets/expert_traj/Swimmer-v2_expert_traj.p --beta 0.005 --transfer
25 | ```
26 |
27 | ## Cite Our Paper
28 | If you find it useful, please consider citing our paper.
29 | ```
30 | @article{liu2019state,
31 | title={State Alignment-based Imitation Learning},
32 | author={Liu, Fangchen and Ling, Zhan and Mu, Tongzhou and Su, Hao},
33 | journal={arXiv preprint arXiv:1911.10947},
34 | year={2019}
35 | }
36 | ```
37 |
38 | ## Demonstrations
39 | Please download the demonstrations [here](https://drive.google.com/open?id=1cIqYevPDE2_06Elyo_UqEfHiQvOKP6in) and put them in ``./expert/assets/expert_traj``.
40 |
--------------------------------------------------------------------------------
/utils/render.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 |
5 | def play(env, policy, running_state=None, video_path="tmp.avi", time_limit=999, device='cpu'):
6 | out = None
7 | obs = env.reset()
8 | if running_state is not None:
9 | obs = running_state(obs, update=False)
10 | num = 0
11 |
12 | while True:
13 | img = env.unwrapped.render(mode='rgb_array')[:, :, ::-1].copy()
14 | if out is None:
15 | out = cv2.VideoWriter(
16 | video_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (img.shape[1], img.shape[0]))
17 | out.write(img)
18 | if policy is not None:
19 | obs = torch.tensor(obs).float().unsqueeze(0).to(device)
20 | action = policy.select_action(obs)[0].detach().cpu().numpy()
21 | action = int(action) if policy.is_disc_action else action.astype(np.float32)
22 | else:
23 | action = env.action_space.sample()
24 | obs, rew, done, info = env.step(action)
25 | if running_state is not None:
26 | obs = running_state(obs, update=False)
27 | if done:
28 | obs = env.reset()
29 | num += 1
30 | #assert not info['is_success']
31 | flag = True
32 | if not flag:
33 | print(num, info, rew, done, env.goal, action)
34 | if num == time_limit - 1:
35 | break
36 | env.close()
37 |
38 | def play_action_seq(env, action_seq, video_path="tmp.avi", time_limit=999):
39 | out = None
40 | obs = env.reset()
41 | t = 0
42 | reward = 0
43 | while True:
44 | img = env.unwrapped.render(mode='rgb_array')[:, :, ::-1].copy()
45 | if out is None:
46 | out = cv2.VideoWriter(
47 | video_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (img.shape[1], img.shape[0]))
48 | out.write(img)
49 | action = action_seq[t]
50 | obs, rew, done, info = env.step(action)
51 | reward+=rew
52 | t+=1
53 | if t > time_limit:
54 | print('accumulated reward', reward)
55 | break
56 | env.close()
--------------------------------------------------------------------------------
/utils/torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from models.dynamics import ForwardModel, InverseModel
4 |
5 | tensor = torch.tensor
6 | DoubleTensor = torch.DoubleTensor
7 | FloatTensor = torch.FloatTensor
8 | LongTensor = torch.LongTensor
9 | ByteTensor = torch.ByteTensor
10 | ones = torch.ones
11 | zeros = torch.zeros
12 |
13 |
14 | def to_device(device, *args):
15 | for x in args:
16 | if isinstance(x, ForwardModel) or isinstance(x, InverseModel):
17 | x.device = device
18 | return [x.to(device) for x in args]
19 |
20 |
21 | def get_flat_params_from(model):
22 | params = []
23 | for param in model.parameters():
24 | params.append(param.view(-1))
25 |
26 | flat_params = torch.cat(params)
27 | return flat_params
28 |
29 |
30 | def set_flat_params_to(model, flat_params):
31 | prev_ind = 0
32 | for param in model.parameters():
33 | flat_size = int(np.prod(list(param.size())))
34 | param.data.copy_(
35 | flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
36 | prev_ind += flat_size
37 |
38 |
39 | def get_flat_grad_from(inputs, grad_grad=False):
40 | grads = []
41 | for param in inputs:
42 | if grad_grad:
43 | grads.append(param.grad.grad.view(-1))
44 | else:
45 | if param.grad is None:
46 | grads.append(zeros(param.view(-1).shape))
47 | else:
48 | grads.append(param.grad.view(-1))
49 |
50 | flat_grad = torch.cat(grads)
51 | return flat_grad
52 |
53 |
54 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False):
55 | if create_graph:
56 | retain_graph = True
57 |
58 | inputs = list(inputs)
59 | params = []
60 | for i, param in enumerate(inputs):
61 | if i not in filter_input_ids:
62 | params.append(param)
63 |
64 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph)
65 |
66 | j = 0
67 | out_grads = []
68 | for i, param in enumerate(inputs):
69 | if i in filter_input_ids:
70 | out_grads.append(zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
71 | else:
72 | out_grads.append(grads[j].view(-1))
73 | j += 1
74 | grads = torch.cat(out_grads)
75 |
76 | for param in params:
77 | param.grad = None
78 | return grads
79 |
--------------------------------------------------------------------------------
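
Besides the dtype/constructor aliases and `to_device`, this module provides the usual flat-parameter and flat-gradient helpers: `get_flat_params_from` concatenates every parameter of a module into one vector, and `set_flat_params_to` writes such a vector back. A quick round-trip sketch:

```python
import torch
import torch.nn as nn
from utils.torch import get_flat_params_from, set_flat_params_to

net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))

# Concatenate every weight and bias into a single 1-D vector.
flat = get_flat_params_from(net)
print(flat.shape)  # torch.Size([58]) for this architecture

# Perturb the vector and write it back into the network in place.
perturbed = flat + 0.01 * torch.randn_like(flat)
set_flat_params_to(net, perturbed)
assert torch.allclose(get_flat_params_from(net), perturbed)
```
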
/envs/mujoco/assets/light_swimmer.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump]
--------------------------------------------------------------------------------
/envs/mujoco/assets/swimmer.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump]
--------------------------------------------------------------------------------
/envs/mujoco/assets/heavy_swimmer.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump]
--------------------------------------------------------------------------------
/envs/mujoco/assets/disabled_swimmer.xml:
--------------------------------------------------------------------------------
[XML content not captured in this dump]
--------------------------------------------------------------------------------
/models/WGAN.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | class Generator(nn.Module):
5 |
6 | def __init__(self, input_dim, output_dim, hidden_size=400):
7 | super(Generator, self).__init__()
8 | main = nn.Sequential(
9 | nn.Linear(input_dim, hidden_size),
10 | nn.ReLU(),
11 | nn.Linear(hidden_size, hidden_size),
12 | nn.ReLU(),
13 | nn.Linear(hidden_size, hidden_size),
14 | nn.ReLU(),
15 | nn.Linear(hidden_size, output_dim),
16 | )
17 | self.main = main
18 |
19 | def forward(self, noise):
20 | output = self.main(noise)
21 | return output
22 |
23 | class Discriminator(nn.Module):
24 | def __init__(self, num_inputs, layers=3, hidden_size=400, activation='tanh'):
25 | super(Discriminator, self).__init__()
26 | if activation == 'tanh':
27 | self.activation = torch.tanh
28 | elif activation == 'relu':
29 | self.activation = torch.relu
30 | elif activation == 'sigmoid':
31 | self.activation = torch.sigmoid
32 | self.num_layers = layers
33 | self.affine_layers = nn.ModuleList()
34 |
35 | for i in range(self.num_layers):
36 | if i == 0:
37 | self.affine_layers.append(nn.Linear(num_inputs, hidden_size))
38 | else:
39 | self.affine_layers.append(nn.Linear(hidden_size, hidden_size))
40 |
41 | self.logic = nn.Linear(hidden_size, 1)
42 | self.logic.weight.data.mul_(0.1)
43 | self.logic.bias.data.mul_(0.0)
44 |
45 | def forward(self, x):
46 | for affine in self.affine_layers:
47 | x = self.activation(affine(x))
48 | prob = torch.sigmoid(self.logic(x))
49 | return prob
50 |
51 | #for wgan: remove sigmoid in the last layer of discriminator
52 | class W_Discriminator(nn.Module):
53 |
54 | def __init__(self, input_dim, hidden_size=400, layers=3):
55 | super(W_Discriminator, self).__init__()
56 | self.activation = torch.relu
57 | self.num_layers = layers
58 | self.affine_layers = nn.ModuleList()
59 |
60 | for i in range(self.num_layers):
61 | if i == 0:
62 | self.affine_layers.append(nn.Linear(input_dim, hidden_size))
63 | else:
64 | self.affine_layers.append(nn.Linear(hidden_size, hidden_size))
65 |
66 | self.final = nn.Linear(hidden_size, 1)
67 | self.final.weight.data.mul_(0.1)
68 | self.final.bias.data.mul_(0.0)
69 |
70 | def forward(self, x):
71 | for affine in self.affine_layers:
72 | x = self.activation(affine(x))
73 | res = self.final(x)
74 | return res.view(-1)
75 |
76 | def weights_init(m):
77 | classname = m.__class__.__name__
78 | if classname.find('Linear') != -1:
79 | m.weight.data.normal_(0.0, 0.02)
80 | m.bias.data.fill_(0)
81 | elif classname.find('BatchNorm') != -1:
82 | m.weight.data.normal_(1.0, 0.02)
83 | m.bias.data.fill_(0)
--------------------------------------------------------------------------------
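
Three small MLPs: a `Generator`, a GAN `Discriminator` with a sigmoid output, and `W_Discriminator`, which drops the sigmoid so its raw output can serve as a Wasserstein critic score. A minimal forward-pass sketch (dimensions are illustrative):

```python
import torch
from models.WGAN import Generator, Discriminator, W_Discriminator, weights_init

noise_dim, data_dim, batch = 16, 8, 32

gen = Generator(noise_dim, data_dim)
disc = Discriminator(data_dim)          # outputs probabilities in (0, 1)
critic = W_Discriminator(data_dim)      # outputs unbounded critic scores
critic.apply(weights_init)

noise = torch.randn(batch, noise_dim)
fake = gen(noise)                       # (32, 8) generated samples
real = torch.randn(batch, data_dim)

print(disc(fake).shape)                 # (32, 1) probabilities
# WGAN critic objective (no log, no sigmoid): maximize E[critic(real)] - E[critic(fake)]
wasserstein_gap = critic(real).mean() - critic(fake.detach()).mean()
print(wasserstein_gap.item())
```
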
/expert/sac_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_sac_args():
4 | parser = argparse.ArgumentParser(description='PyTorch GAIL example')
5 | parser.add_argument('--env-name', default="Hopper-v2",
6 | help='name of the environment to run')
7 | parser.add_argument('--policy', default="Gaussian",
8 | help='algorithm to use: Gaussian | Deterministic')
9 | parser.add_argument('--eval', type=bool, default=True,
10 | help='Evaluate the policy every 10 episodes (default: True)')
11 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
12 | help='discount factor for reward (default: 0.99)')
13 | parser.add_argument('--tau', type=float, default=0.005, metavar='G',
14 | help='target smoothing coefficient(τ) (default: 0.005)')
15 | parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
16 | help='learning rate (default: 0.0003)')
17 | parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
18 | help='Temperature parameter α determines the relative importance of the entropy term against the reward (default: 0.2)')
19 | parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
20 | help='Temperature parameter α automatically adjusted.')
21 | parser.add_argument('--seed', type=int, default=456, metavar='N',
22 | help='random seed (default: 456)')
23 | parser.add_argument('--batch-size', type=int, default=256, metavar='N',
24 | help='batch size (default: 256)')
25 | parser.add_argument('--num-steps', type=int, default=1000001, metavar='N',
26 | help='maximum number of steps (default: 1000000)')
27 | parser.add_argument('--hidden-size', type=int, default=400, metavar='N',
28 | help='hidden size (default: 400)')
29 | parser.add_argument('--updates-per-step', type=int, default=1, metavar='N',
30 | help='model updates per simulator step (default: 1)')
31 | parser.add_argument('--start-steps', type=int, default=300, metavar='N',
32 | help='Steps sampling random actions (default: 300)')
33 | parser.add_argument('--target-update-interval', type=int, default=1, metavar='N',
34 | help='Value target update per no. of updates per step (default: 1)')
35 | parser.add_argument('--replay-size', type=int, default=1000000, metavar='N',
36 | help='size of replay buffer (default: 1000000)')
37 | parser.add_argument('--device', type=str, default="cuda:0",
38 | help='device to run on (default: cuda:0)')
39 | parser.add_argument('--resume', type=bool, default=False,
40 | help='resume training from a saved model (default: False)')
41 | parser.add_argument('--model-path', type=str, default='learned_models/sac')
42 | parser.add_argument('--expert-traj-path', metavar='G',
43 | help='path of the expert trajectories')
44 | parser.add_argument('--rnd-epoch', type=int, default=600,
45 | help='number of RND training epochs (default: 600)')
46 | args = parser.parse_args()
47 |
48 | return args
--------------------------------------------------------------------------------
/expert/save_traj_ppo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gym
3 | import os
4 | import sys
5 | import pickle
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 |
8 | from itertools import count
9 | from utils import *
10 |
11 | parser = argparse.ArgumentParser(description='Save expert trajectory')
12 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G',
13 | help='name of the environment to run')
14 | parser.add_argument('--model-path', metavar='G',
15 | help='name of the expert model')
16 | parser.add_argument('--render', action='store_true', default=False,
17 | help='render the environment')
18 | parser.add_argument('--seed', type=int, default=1, metavar='N',
19 | help='random seed (default: 1)')
20 | parser.add_argument('--max-expert-state-num', type=int, default=50000, metavar='N',
21 | help='maximal number of main iterations (default: 50000)')
22 | parser.add_argument('--running-state', type=int, default=0)
23 | args = parser.parse_args()
24 |
25 | dtype = torch.float32
26 | torch.set_default_dtype(dtype)
27 | env = gym.make(args.env_name)
28 | env.seed(args.seed)
29 | torch.manual_seed(args.seed)
30 | is_disc_action = len(env.action_space.shape) == 0
31 | state_dim = env.observation_space.shape[0]
32 |
33 | if args.running_state == 1:
34 | print('use running state')
35 | policy_net, _, running_state = pickle.load(open(args.model_path, "rb"))
36 | else:
37 | print('no running state')
38 | policy_net, _ = pickle.load(open(args.model_path, "rb"))
39 |
40 | expert_trajs = []
41 | policy_net.to(dtype)
42 | def main_loop():
43 |
44 | num_steps = 0
45 |
46 | for i_episode in count():
47 | expert_traj = []
48 | state = env.reset()
49 | if args.running_state:
50 | state = running_state(state)
51 | reward_episode = 0
52 |
53 | for t in range(10000):
54 | state_var = tensor(state).unsqueeze(0).to(dtype)
55 | # choose mean action
56 | action = policy_net(state_var)[0][0].detach().numpy()
57 | # choose stochastic action
58 | # action = policy_net.select_action(state_var)[0].cpu().numpy()
59 | action = int(action) if is_disc_action else action.astype(np.float64)
60 | next_state, reward, done, _ = env.step(action)
61 | if args.running_state:
62 | next_state = running_state(next_state)
63 | reward_episode += reward
64 | num_steps += 1
65 |
66 | expert_traj.append(np.hstack([state, action]))
67 |
68 | if args.render:
69 | env.render()
70 | if done:
71 | expert_traj = np.stack(expert_traj)
72 | expert_trajs.append(expert_traj)
73 | break
74 |
75 | state = next_state
76 |
77 | print('Episode {}\t reward: {:.2f}'.format(i_episode, reward_episode))
78 |
79 | if num_steps >= args.max_expert_state_num:
80 | break
81 |
82 |
83 | main_loop()
84 | if args.running_state:
85 | pickle.dump((expert_trajs, running_state), open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(args.env_name)), \
86 | 'wb'))
87 | else:
88 | pickle.dump(expert_trajs, open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(args.env_name)),\
89 | 'wb'))
90 |
--------------------------------------------------------------------------------
/expert/save_traj_sac.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import gym
3 | import itertools
4 | from agents.sac_agent import SAC_agent
5 | from utils import *
6 | import argparse
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser(description='PyTorch GAIL example')
10 | parser.add_argument('--env-name', default="Hopper-v2",
11 | help='name of the environment to run')
12 | parser.add_argument('--policy', default="Gaussian",
13 | help='algorithm to use: Gaussian | Deterministic')
14 | parser.add_argument('--eval', type=bool, default=True,
15 | help='Evaluate the policy every 10 episodes (default: True)')
16 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
17 | help='discount factor for reward (default: 0.99)')
18 | parser.add_argument('--tau', type=float, default=0.005, metavar='G',
19 | help='target smoothing coefficient(τ) (default: 0.005)')
20 | parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
21 | help='learning rate (default: 0.0003)')
22 | parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
23 | help='Temperature parameter α determines the relative importance of the entropy term against the reward (default: 0.2)')
24 | parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
25 | help='Temperature parameter α automatically adjusted.')
26 | parser.add_argument('--seed', type=int, default=456, metavar='N',
27 | help='random seed (default: 456)')
28 | parser.add_argument('--batch-size', type=int, default=256, metavar='N',
29 | help='batch size (default: 256)')
30 | parser.add_argument('--num-steps', type=int, default=1000001, metavar='N',
31 | help='maximum number of steps (default: 1000000)')
32 | parser.add_argument('--hidden-size', type=int, default=400, metavar='N',
33 | help='hidden size (default: 400)')
34 | parser.add_argument('--updates-per-step', type=int, default=1, metavar='N',
35 | help='model updates per simulator step (default: 1)')
36 | parser.add_argument('--start-steps', type=int, default=300, metavar='N',
37 | help='Steps sampling random actions (default: 300)')
38 | parser.add_argument('--target-update-interval', type=int, default=1, metavar='N',
39 | help='Value target update per no. of updates per step (default: 1)')
40 | parser.add_argument('--replay-size', type=int, default=1000000, metavar='N',
41 | help='size of replay buffer (default: 1000000)')
42 | parser.add_argument('--device', type=str, default="cuda:0",
43 | help='device to run on (default: cuda:0)')
44 | parser.add_argument('--actor-path', type=str, default='assets/learned_models/sac_actor_Hopper-v2_1', help='actor resume path')
45 | parser.add_argument('--critic-path', type=str, default='assets/learned_models/sac_critic_Hopper-v2_1', help='critic resume path')
46 |
47 | args = parser.parse_args()
48 |
49 | return args
50 |
51 | args = get_args()
52 |
53 | # Environment
54 | # env = NormalizedActions(gym.make(args.env_name))
55 | env = gym.make(args.env_name)
56 | torch.manual_seed(args.seed)
57 | np.random.seed(args.seed)
58 | env.seed(args.seed)
59 | state_dim = env.observation_space.shape[0]
60 | agent = SAC_agent(env, env.observation_space.shape[0], env.action_space, args, running_state=None)
61 | agent.load_model(actor_path=args.actor_path, critic_path=args.critic_path)
62 | agent.save_expert_traj(max_step=50000)
--------------------------------------------------------------------------------
/expert/train_sac.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import gym
3 | import itertools
4 | from agents.sac_agent import SAC_agent
5 | from tensorboardX import SummaryWriter
6 | from models.rnd_model import RND
7 | from expert.sac_args import get_sac_args
8 | from utils import *
9 |
10 | args = get_sac_args()
11 |
12 | # Environment
13 | # env = NormalizedActions(gym.make(args.env_name))
14 | env = gym.make(args.env_name)
15 | torch.manual_seed(args.seed)
16 | np.random.seed(args.seed)
17 | env.seed(args.seed)
18 | state_dim = env.observation_space.shape[0]
19 | agent = SAC_agent(env, env.observation_space.shape[0], env.action_space, args, running_state=None)
20 | writer = SummaryWriter(log_dir='runs/{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
21 | args.policy, "autotune" if args.automatic_entropy_tuning else ""))
22 | # Memory
23 | memory = Memory(args.replay_size)
24 |
25 | # Training Loop
26 | total_numsteps = 0
27 | updates = 0
28 | #agent.save_expert_traj()
29 | for i_episode in itertools.count(1):
30 | episode_reward = 0
31 | episode_steps = 0
32 | done = False
33 | state = env.reset()
34 |
35 | while not done:
36 | if args.start_steps > total_numsteps:
37 | action = env.action_space.sample() # Sample random action
38 | else:
39 | action = agent.select_action(state) # Sample action from policy
40 |
41 | if len(memory) > args.batch_size:
42 | # Number of updates per step in environment
43 | for i in range(args.updates_per_step):
44 | # Update parameters of all the networks
45 | critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, args.batch_size, updates)
46 |
47 | writer.add_scalar('loss/critic_1', critic_1_loss, updates)
48 | writer.add_scalar('loss/critic_2', critic_2_loss, updates)
49 | writer.add_scalar('loss/policy', policy_loss, updates)
50 | writer.add_scalar('loss/entropy_loss', ent_loss, updates)
51 | writer.add_scalar('entropy_temperature/alpha', alpha, updates)
52 | updates += 1
53 |
54 | next_state, reward, done, _ = env.step(action) # Step
55 | episode_steps += 1
56 | total_numsteps += 1
57 | episode_reward += reward
58 |
59 | mask = 1 if episode_steps == env._max_episode_steps else float(not done)
60 |
61 | memory.push(state, action, reward, next_state, mask) # Append transition to memory
62 |
63 | state = next_state
64 |
65 | if total_numsteps > args.num_steps:
66 | break
67 |
68 | writer.add_scalar('reward/sac_train', episode_reward, total_numsteps)
69 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))
70 |
71 | if i_episode % 10 == 0 and args.eval == True:
72 | avg_reward = 0.
73 | episodes = 10
74 | for _ in range(episodes):
75 | state = env.reset()
76 | episode_reward = 0
77 | done = False
78 | while not done:
79 | action = agent.select_action(state, eval=True)
80 |
81 | next_state, reward, done, _ = env.step(action)
82 | episode_reward += reward
83 |
84 |
85 | state = next_state
86 | avg_reward += episode_reward
87 | avg_reward /= episodes
88 | writer.add_scalar('avg_reward/sac_test', avg_reward, total_numsteps)
89 |
90 | print("----------------------------------------")
91 | print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
92 | print("----------------------------------------")
93 |
94 | if i_episode % 1000 == 1:
95 | agent.save_model(args.env_name, i_episode)
96 | # get trajectory of the expert
97 | # agent.save_expert_traj(max_step=50000)
98 |
99 | env.close()
--------------------------------------------------------------------------------
/envs/mujoco/swimmer.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 | import numpy as np
7 | from gym import utils
8 | from gym.envs.mujoco import mujoco_env
9 |
10 | class DisableSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
11 | def __init__(self):
12 | dir_path = os.path.dirname(os.path.realpath(__file__))
13 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/disabled_swimmer.xml' % dir_path, 4)
14 | utils.EzPickle.__init__(self)
15 |
16 | def step(self, a):
17 | ctrl_cost_coeff = 0.0001
18 | xposbefore = self.sim.data.qpos[0]
19 | self.do_simulation(a, self.frame_skip)
20 | xposafter = self.sim.data.qpos[0]
21 | reward_fwd = (xposafter - xposbefore) / self.dt
22 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
23 | reward = reward_fwd + reward_ctrl
24 | ob = self._get_obs()
25 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
26 |
27 | def _get_obs(self):
28 | qpos = self.sim.data.qpos
29 | qvel = self.sim.data.qvel
30 | return np.concatenate([qpos.flat[2:], qvel.flat])
31 |
32 | def reset_model(self):
33 | self.set_state(
34 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
35 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
36 | )
37 | return self._get_obs()
38 |
39 | class HeavySwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
40 | def __init__(self):
41 | dir_path = os.path.dirname(os.path.realpath(__file__))
42 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/heavy_swimmer.xml' % dir_path, 4)
43 | utils.EzPickle.__init__(self)
44 |
45 | def step(self, a):
46 | ctrl_cost_coeff = 0.0001
47 | xposbefore = self.sim.data.qpos[0]
48 | self.do_simulation(a, self.frame_skip)
49 | xposafter = self.sim.data.qpos[0]
50 | reward_fwd = (xposafter - xposbefore) / self.dt
51 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
52 | reward = reward_fwd + reward_ctrl
53 | ob = self._get_obs()
54 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
55 |
56 | def _get_obs(self):
57 | qpos = self.sim.data.qpos
58 | qvel = self.sim.data.qvel
59 | return np.concatenate([qpos.flat[2:], qvel.flat])
60 |
61 | def reset_model(self):
62 | self.set_state(
63 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
64 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
65 | )
66 | return self._get_obs()
67 |
68 | class LightSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
69 | def __init__(self):
70 | dir_path = os.path.dirname(os.path.realpath(__file__))
71 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/light_swimmer.xml' % dir_path, 4)
72 | utils.EzPickle.__init__(self)
73 |
74 | def step(self, a):
75 | ctrl_cost_coeff = 0.0001
76 | xposbefore = self.sim.data.qpos[0]
77 | self.do_simulation(a, self.frame_skip)
78 | xposafter = self.sim.data.qpos[0]
79 | reward_fwd = (xposafter - xposbefore) / self.dt
80 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
81 | reward = reward_fwd + reward_ctrl
82 | ob = self._get_obs()
83 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
84 |
85 | def _get_obs(self):
86 | qpos = self.sim.data.qpos
87 | qvel = self.sim.data.qvel
88 | return np.concatenate([qpos.flat[2:], qvel.flat])
89 |
90 | def reset_model(self):
91 | self.set_state(
92 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
93 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
94 | )
95 | return self._get_obs()
--------------------------------------------------------------------------------
/models/VAE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import torchvision
6 | from torchvision import transforms
7 | import torch.optim as optim
8 | from torch import nn
9 | import matplotlib.pyplot as plt
10 | from torch.distributions import Normal
11 | from models.ppo_models import weights_init_
12 |
13 | MAX_LOG_STD = 0.5
14 | MIN_LOG_STD = -20
15 |
16 | def latent_loss(z_mean, z_stddev):
17 | mean_sq = z_mean * z_mean
18 | stddev_sq = z_stddev * z_stddev
19 | return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1)
20 |
21 | class Encoder(torch.nn.Module):
22 | def __init__(self, input_dim, latent_dim, hidden_size=128):
23 | super(Encoder, self).__init__()
24 | self.linear1 = nn.Linear(input_dim, hidden_size)
25 | self.linear2 = nn.Linear(hidden_size, hidden_size)
26 | self.linear3 = nn.Linear(hidden_size, hidden_size)
27 | self.mu = nn.Linear(hidden_size, latent_dim)
28 | self.log_std = nn.Linear(hidden_size, latent_dim)
29 | self.apply(weights_init_)
30 |
31 | def forward(self, x):
32 | x = F.relu(self.linear1(x))
33 | x = F.relu(self.linear2(x))
34 | x = F.relu(self.linear3(x))
35 | mean = self.mu(x)
36 | log_std = self.log_std(x)
37 | log_std = torch.clamp(log_std, min=MIN_LOG_STD, max=MAX_LOG_STD)
38 | return mean, log_std
39 |
40 |
41 | class Decoder(torch.nn.Module):
42 | def __init__(self, input_dim, out_dim, hidden_size=128):
43 | super(Decoder, self).__init__()
44 | self.linear1 = torch.nn.Linear(input_dim, hidden_size)
45 | self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
46 | self.linear3 = torch.nn.Linear(hidden_size, out_dim)
47 | self.apply(weights_init_)
48 |
49 | def forward(self, x):
50 | x = F.relu(self.linear1(x))
51 | x = F.relu(self.linear2(x))
52 | x = self.linear3(x)
53 | return x
54 |
55 |
56 | class VAE(torch.nn.Module):
57 | def __init__(self, state_dim, hidden_size=128, latent_dim=128):
58 | super(VAE, self).__init__()
59 | self.hidden_size = hidden_size
60 | self.encoder = Encoder(state_dim, latent_dim=latent_dim, hidden_size=self.hidden_size)
61 | self.decoder = Decoder(latent_dim, state_dim, hidden_size=self.hidden_size)
62 |
63 | def forward(self, state):
64 | mu, log_sigma = self.encoder(state)
65 | sigma = torch.exp(log_sigma)
66 | sample = mu + torch.randn_like(mu)*sigma
67 | self.z_mean = mu
68 | self.z_sigma = sigma
69 |
70 | return self.decoder(sample)
71 |
72 | def to(self, device):
73 | self.encoder.to(device)
74 | self.decoder.to(device)
75 |
76 | def get_next_states(self, states):
77 | mu, log_sigma = self.encoder(states)
78 | return self.decoder(mu)
79 |
80 | def get_loss(self, state, next_state):
81 | next_pred = self.get_next_states(state)
82 | return ((next_state-next_pred)**2).mean()
83 |
84 | def train(self, input, target, epoch, optimizer, batch_size=128, beta=0.1):
85 | idxs = np.arange(input.shape[0])
86 | np.random.shuffle(idxs)
87 | num_batch = int(np.ceil(idxs.shape[-1] / batch_size))
88 | for epoch in range(epoch):
89 | idxs = np.arange(input.shape[0])
90 | np.random.shuffle(idxs)
91 | for batch_num in range(num_batch):
92 | batch_idxs = idxs[batch_num * batch_size : (batch_num + 1) * batch_size]
93 | train_in = input[batch_idxs].float()
94 | train_targ = target[batch_idxs].float()
95 | optimizer.zero_grad()
96 | dec = self.forward(train_in)
97 | reconstruct_loss = ((train_targ-dec)**2).mean()
98 | ll = latent_loss(self.z_mean, self.z_sigma)
99 | loss = reconstruct_loss + beta*ll
100 | loss.backward()
101 | optimizer.step()
102 | val_input = input[idxs]
103 | val_target = target[idxs]
104 | val_dec = self.get_next_states(val_input)
105 | loss = ((val_target-val_dec)**2).mean().item()
106 | #print('vae loss', loss)
107 | return loss
--------------------------------------------------------------------------------
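
The VAE pairs a Gaussian `Encoder` with a deterministic `Decoder`; its `train` method minimizes reconstruction error plus `beta` times the KL term from `latent_loss`, and `get_next_states` predicts deterministically through the latent mean. A minimal training sketch on synthetic data (the dimensions and the random linear "dynamics" are illustrative):

```python
import torch
from models.VAE import VAE

state_dim = 8
vae = VAE(state_dim, hidden_size=128, latent_dim=16)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# Learn to map states to next states under a fixed random linear dynamics.
states = torch.randn(2048, state_dim)
next_states = 0.1 * (states @ torch.randn(state_dim, state_dim))

val_loss = vae.train(states, next_states, epoch=20, optimizer=optimizer,
                     batch_size=128, beta=0.1)
print('reconstruction MSE on the training set:', val_loss)

# Deterministic prediction through the latent mean.
pred_next = vae.get_next_states(states[:5])
```
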
/core/agent.py:
--------------------------------------------------------------------------------
1 | from utils.replay_memory import OnlineMemory
2 | from utils.torch import *
3 | import math
4 | import time
5 |
6 | def collect_samples(env, policy, custom_reward,
7 | mean_action, render, running_state, min_batch_size, update, trajectory, dtype=torch.float32):
8 | log = dict()
9 | memory = OnlineMemory()
10 | num_steps = 0
11 | total_reward = 0
12 | min_reward = 1e6
13 | max_reward = -1e6
14 | total_c_reward = 0
15 | min_c_reward = 1e6
16 | max_c_reward = -1e6
17 | num_episodes = 0
18 | if dtype == torch.float32:
19 | numpy_dtype = np.float32
20 | else:
21 | numpy_dtype = np.float64
22 |
23 | while num_steps < min_batch_size:
24 | state = env.reset()
25 | episode_steps = 0
26 | if running_state is not None:
27 | state = running_state(state, update=update)
28 | reward_episode = 0
29 |
30 | for t in range(10000):
31 | state_var = tensor(state).to(dtype).unsqueeze(0)
32 | with torch.no_grad():
33 | if mean_action:
34 | action = policy(state_var)[0][0].numpy()
35 | else:
36 | action = policy.select_action(state_var)[0].numpy()
37 | action = int(action) if policy.is_disc_action else action.astype(numpy_dtype)
38 | next_state, reward, done, _ = env.step(action)
39 | episode_steps += 1
40 | reward_episode += reward
41 |
42 | if running_state is not None:
43 | next_state = running_state(next_state, update=update)
44 | #print(next_state, running_state.rs.n)
45 |
46 | if episode_steps == env._max_episode_steps:
47 | mask = 1.
48 | else:
49 | mask = float(not done)
50 |
51 | if custom_reward is not None:
52 | reward = custom_reward(state, action, next_state, done)
53 | total_c_reward += reward
54 | min_c_reward = min(min_c_reward, reward)
55 | max_c_reward = max(max_c_reward, reward)
56 |
57 | memory.push(state, action, reward, next_state, mask)
58 | if trajectory is not None:
59 | trajectory.push(state, action, reward, next_state, mask)
60 |
61 | if render:
62 | env.render()
63 | if done:
64 | if trajectory is not None:
65 | trajectory.clear()
66 | break
67 |
68 | state = next_state
69 |
70 | # log stats
71 | num_steps += (t + 1)
72 | num_episodes += 1
73 | total_reward += reward_episode
74 | min_reward = min(min_reward, reward_episode)
75 | max_reward = max(max_reward, reward_episode)
76 |
77 | log['num_steps'] = num_steps
78 | log['num_episodes'] = num_episodes
79 | log['total_reward'] = total_reward
80 | log['avg_reward'] = total_reward / num_episodes
81 | log['max_reward'] = max_reward
82 | log['min_reward'] = min_reward
83 | if custom_reward is not None:
84 | log['total_c_reward'] = total_c_reward
85 | log['avg_c_reward'] = total_c_reward / num_episodes
86 | log['max_c_reward'] = max_c_reward
87 | log['min_c_reward'] = min_c_reward
88 |
89 | return memory, log
90 |
91 | class Agent:
92 |
93 | def __init__(self, env, policy, device, custom_reward=None,
94 | mean_action=False, render=False, running_state=None, update=True, dtype=torch.float32):
95 | self.env = env
96 | self.policy = policy
97 | self.device = device
98 | self.custom_reward = custom_reward
99 | self.mean_action = mean_action
100 | self.running_state = running_state
101 | self.render = render
102 | self.update = update
103 | self.dtype = dtype
104 |
105 | def collect_samples(self, min_batch_size, trajectory=None):
106 | t_start = time.time()
107 | to_device(torch.device('cpu'), self.policy)
108 | memory, log = collect_samples(self.env, self.policy, self.custom_reward, self.mean_action,
109 | self.render, self.running_state, min_batch_size, self.update, trajectory, self.dtype)
110 |
111 | batch = memory.sample()
112 | to_device(self.device, self.policy)
113 | t_end = time.time()
114 | log['sample_time'] = t_end - t_start
115 | log['action_mean'] = np.mean(np.vstack(batch.action), axis=0)
116 | log['action_min'] = np.min(np.vstack(batch.action), axis=0)
117 | log['action_max'] = np.max(np.vstack(batch.action), axis=0)
118 | return batch, log
119 |
--------------------------------------------------------------------------------
/models/sac_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal
5 |
6 | LOG_SIG_MAX = 2
7 | LOG_SIG_MIN = -20
8 | epsilon = 1e-6
9 |
10 | # Initialize Policy weights
11 | def weights_init_(m):
12 | if isinstance(m, nn.Linear):
13 | torch.nn.init.xavier_uniform_(m.weight, gain=1)
14 | torch.nn.init.constant_(m.bias, 0)
15 |
16 |
17 | class QNetwork(nn.Module):
18 | def __init__(self, num_inputs, num_actions, hidden_dim):
19 | super(QNetwork, self).__init__()
20 |
21 | # Q1 architecture
22 | self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
23 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
24 | self.linear3 = nn.Linear(hidden_dim, 1)
25 |
26 | # Q2 architecture
27 | self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
28 | self.linear5 = nn.Linear(hidden_dim, hidden_dim)
29 | self.linear6 = nn.Linear(hidden_dim, 1)
30 |
31 | self.apply(weights_init_)
32 |
33 | def forward(self, state, action):
34 | xu = torch.cat([state, action], 1)
35 |
36 | x1 = F.relu(self.linear1(xu))
37 | x1 = F.relu(self.linear2(x1))
38 | x1 = self.linear3(x1)
39 |
40 | x2 = F.relu(self.linear4(xu))
41 | x2 = F.relu(self.linear5(x2))
42 | x2 = self.linear6(x2)
43 |
44 | return x1, x2
45 |
46 |
47 | class GaussianPolicy(nn.Module):
48 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
49 | super(GaussianPolicy, self).__init__()
50 |
51 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
52 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
53 |
54 | self.mean_linear = nn.Linear(hidden_dim, num_actions)
55 | self.log_std_linear = nn.Linear(hidden_dim, num_actions)
56 |
57 | self.apply(weights_init_)
58 |
59 | # action rescaling
60 | if action_space is None:
61 | self.action_scale = torch.tensor(1.)
62 | self.action_bias = torch.tensor(0.)
63 | else:
64 | self.action_scale = torch.FloatTensor(
65 | (action_space.high - action_space.low) / 2.)
66 | self.action_bias = torch.FloatTensor(
67 | (action_space.high + action_space.low) / 2.)
68 |
69 | def forward(self, state):
70 | x = F.relu(self.linear1(state))
71 | x = F.relu(self.linear2(x))
72 | mean = self.mean_linear(x)
73 | log_std = self.log_std_linear(x)
74 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
75 | return mean, log_std
76 |
77 | def sample(self, state):
78 | mean, log_std = self.forward(state)
79 | std = log_std.exp()
80 | normal = Normal(mean, std)
81 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
82 | y_t = torch.tanh(x_t)
83 | action = y_t * self.action_scale + self.action_bias
84 | log_prob = normal.log_prob(x_t)
85 | # Enforcing Action Bound
86 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
87 | log_prob = log_prob.sum(1, keepdim=True)
88 | return action, log_prob, torch.tanh(mean)
89 |
90 | def to(self, device):
91 | self.action_scale = self.action_scale.to(device)
92 | self.action_bias = self.action_bias.to(device)
93 | return super(GaussianPolicy, self).to(device)
94 |
95 |
96 | class DeterministicPolicy(nn.Module):
97 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None):
98 | super(DeterministicPolicy, self).__init__()
99 | self.linear1 = nn.Linear(num_inputs, hidden_dim)
100 | self.linear2 = nn.Linear(hidden_dim, hidden_dim)
101 |
102 | self.mean = nn.Linear(hidden_dim, num_actions)
103 | self.noise = torch.Tensor(num_actions)
104 |
105 | self.apply(weights_init_)
106 |
107 | # action rescaling
108 | if action_space is None:
109 | self.action_scale = 1.
110 | self.action_bias = 0.
111 | else:
112 | self.action_scale = torch.FloatTensor(
113 | (action_space.high - action_space.low) / 2.)
114 | self.action_bias = torch.FloatTensor(
115 | (action_space.high + action_space.low) / 2.)
116 |
117 | def forward(self, state):
118 | x = F.relu(self.linear1(state))
119 | x = F.relu(self.linear2(x))
120 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias
121 | return mean
122 |
123 | def sample(self, state):
124 | mean = self.forward(state)
125 | noise = self.noise.normal_(0., std=0.1)
126 | noise = noise.clamp(-0.25, 0.25)
127 | action = mean + noise
128 | return action, torch.tensor(0.), mean
129 |
130 | def to(self, device):
131 | self.action_scale = self.action_scale.to(device)
132 | self.action_bias = self.action_bias.to(device)
133 | return super(DeterministicPolicy, self).to(device)
134 |
--------------------------------------------------------------------------------
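
The tanh squashing in GaussianPolicy.sample changes the density of the sampled action, which is why the log-probability subtracts log(action_scale * (1 - tanh(x)^2) + epsilon) before summing over action dimensions. A minimal usage sketch; the 17/6 dimensions and the Box bounds are illustrative assumptions, and LOG_SIG_MIN, LOG_SIG_MAX and epsilon are the module-level constants defined earlier in models/sac_models.py:

    import numpy as np
    import torch
    from gym import spaces
    from models.sac_models import GaussianPolicy

    # Illustrative dimensions only.
    action_space = spaces.Box(low=-1.0, high=1.0, shape=(6,), dtype=np.float32)
    policy = GaussianPolicy(num_inputs=17, num_actions=6, hidden_dim=256, action_space=action_space)

    state = torch.randn(32, 17)                      # a batch of states
    action, log_prob, mean_action = policy.sample(state)
    print(action.shape, log_prob.shape)              # torch.Size([32, 6]) torch.Size([32, 1])
    # `action` is the reparameterized sample used during training;
    # `mean_action` (tanh of the mean) is what the agent uses at evaluation time.
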
/envs/mujoco/assets/ant.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/envs/mujoco/assets/heavy_ant.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/envs/mujoco/assets/disabled_ant.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/envs/mujoco/ant.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 |
5 | import os
6 |
7 | import numpy as np
8 | from gym import utils
9 | from gym.envs.mujoco import mujoco_env
10 |
11 | class DisableAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
12 | def __init__(self):
13 | dir_path = os.path.dirname(os.path.realpath(__file__))
14 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/disabled_ant.xml' % dir_path, 5)
15 | utils.EzPickle.__init__(self)
16 |
17 | def step(self, a):
18 | xposbefore = self.get_body_com("torso")[0]
19 | self.do_simulation(a, self.frame_skip)
20 | xposafter = self.get_body_com("torso")[0]
21 | forward_reward = (xposafter - xposbefore)/self.dt
22 | ctrl_cost = .5 * np.square(a).sum()
23 | contact_cost = 0.5 * 1e-3 * np.sum(
24 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
25 | survive_reward = 1.0
26 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
27 | state = self.state_vector()
28 | notdone = np.isfinite(state).all() \
29 | and state[2] >= 0.2 and state[2] <= 1.0
30 | done = not notdone
31 | ob = self._get_obs()
32 | return ob, reward, done, dict(
33 | reward_forward=forward_reward,
34 | reward_ctrl=-ctrl_cost,
35 | reward_contact=-contact_cost,
36 | reward_survive=survive_reward)
37 |
38 | def _get_obs(self):
39 | return np.concatenate([
40 | self.sim.data.qpos.flat[2:],
41 | self.sim.data.qvel.flat,
42 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
43 | ])
44 |
45 | def reset_model(self):
46 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
47 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
48 | self.set_state(qpos, qvel)
49 | return self._get_obs()
50 |
51 | def viewer_setup(self):
52 | self.viewer.cam.distance = self.model.stat.extent * 0.5
53 |
54 |
55 | class HeavyAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
56 | def __init__(self):
57 | dir_path = os.path.dirname(os.path.realpath(__file__))
58 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/heavy_ant.xml' % dir_path, 5)
59 | utils.EzPickle.__init__(self)
60 |
61 | def step(self, a):
62 | xposbefore = self.get_body_com("torso")[0]
63 | self.do_simulation(a, self.frame_skip)
64 | xposafter = self.get_body_com("torso")[0]
65 | forward_reward = (xposafter - xposbefore)/self.dt
66 | ctrl_cost = .5 * np.square(a).sum()
67 | contact_cost = 0.5 * 1e-3 * np.sum(
68 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
69 | survive_reward = 1.0
70 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
71 | state = self.state_vector()
72 | notdone = np.isfinite(state).all() \
73 | and state[2] >= 0.2 and state[2] <= 1.0
74 | done = not notdone
75 | ob = self._get_obs()
76 | return ob, reward, done, dict(
77 | reward_forward=forward_reward,
78 | reward_ctrl=-ctrl_cost,
79 | reward_contact=-contact_cost,
80 | reward_survive=survive_reward)
81 |
82 | def _get_obs(self):
83 | return np.concatenate([
84 | self.sim.data.qpos.flat[2:],
85 | self.sim.data.qvel.flat,
86 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
87 | ])
88 |
89 | def reset_model(self):
90 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
91 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
92 | self.set_state(qpos, qvel)
93 | return self._get_obs()
94 |
95 | def viewer_setup(self):
96 | self.viewer.cam.distance = self.model.stat.extent * 0.5
97 |
98 |
99 | class LightAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
100 | def __init__(self):
101 | dir_path = os.path.dirname(os.path.realpath(__file__))
102 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/light_ant.xml' % dir_path, 5)
103 | utils.EzPickle.__init__(self)
104 |
105 | def step(self, a):
106 | xposbefore = self.get_body_com("torso")[0]
107 | self.do_simulation(a, self.frame_skip)
108 | xposafter = self.get_body_com("torso")[0]
109 | forward_reward = (xposafter - xposbefore)/self.dt
110 | ctrl_cost = .5 * np.square(a).sum()
111 | contact_cost = 0.5 * 1e-3 * np.sum(
112 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
113 | survive_reward = 1.0
114 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
115 | state = self.state_vector()
116 | notdone = np.isfinite(state).all() \
117 | and state[2] >= 0.2 and state[2] <= 1.0
118 | done = not notdone
119 | ob = self._get_obs()
120 | return ob, reward, done, dict(
121 | reward_forward=forward_reward,
122 | reward_ctrl=-ctrl_cost,
123 | reward_contact=-contact_cost,
124 | reward_survive=survive_reward)
125 |
126 | def _get_obs(self):
127 | return np.concatenate([
128 | self.sim.data.qpos.flat[2:],
129 | self.sim.data.qvel.flat,
130 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
131 | ])
132 |
133 | def reset_model(self):
134 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
135 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
136 | self.set_state(qpos, qvel)
137 | return self._get_obs()
138 |
139 | def viewer_setup(self):
140 | self.viewer.cam.distance = self.model.stat.extent * 0.5
--------------------------------------------------------------------------------
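
All three Ant variants share the same reward decomposition (forward progress minus control and contact costs plus a survival bonus) and differ only in the XML model they load. A minimal sketch of driving one of them directly, assuming mujoco-py and the XML assets above are available and that the envs package imports cleanly:

    import numpy as np
    from envs.mujoco.ant import DisableAntEnv

    env = DisableAntEnv()
    obs = env.reset()
    # Zero action, purely for illustration; info carries the reward terms from step().
    obs, reward, done, info = env.step(np.zeros(env.action_space.shape))
    print(obs.shape, reward, info['reward_forward'], info['reward_ctrl'])
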
/sail.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import pickle
5 |
6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7 |
8 | from utils import *
9 | from utils.utils import generate_pairs, process_expert_traj, generate_tuples, adjust_lr
10 | from agents.soft_bc_agent import SoftBC_agent
11 | from utils.utils import normalize_states, normalize_expert_traj
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser(description='SAIL arguments')
15 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G',
16 | help='name of the environment to run')
17 | parser.add_argument('--expert-traj-path', metavar='G',
18 | help='path of the expert trajectories')
19 | parser.add_argument('--render', action='store_true', default=False,
20 | help='render the environment')
21 | parser.add_argument('--log-std', type=float, default=-5.0, metavar='G',
22 |                         help='log std for the policy (default: -5.0)')
23 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
24 | help='discount factor (default: 0.99)')
25 | parser.add_argument('--tau', type=float, default=0.95, metavar='G',
26 | help='gae (default: 0.95)')
27 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G',
28 | help='l2 regularization regression (default: 1e-3)')
29 |
30 | parser.add_argument('--policy-lr', type=float, default=3e-4, metavar='G',
31 | help='learning rate for policy networks')
32 | parser.add_argument('--value-lr', type=float, default=3e-4, metavar='G',
33 | help='learning rate for value networks')
34 | parser.add_argument('--model-lr', type=float, default=3e-4, metavar='G',
35 | help='learning rate for forward/inverse/vae')
36 |
37 |
38 | parser.add_argument('--clip-epsilon', type=float, default=0.2, metavar='N',
39 | help='clipping epsilon for PPO')
40 | parser.add_argument('--seed', type=int, default=0, metavar='N',
41 |                         help='random seed (default: 0)')
42 | parser.add_argument('--min-batch-size', type=int, default=2048, metavar='N',
43 | help='minimal batch size per PPO update (default: 2048)')
44 | parser.add_argument('--max-iter-num', type=int, default=500, metavar='N',
45 | help='maximal number of main iterations (default: 500)')
46 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
47 |                         help='interval between training status logs (default: 1)')
48 | parser.add_argument('--save-model-interval', type=int, default=50, metavar='N',
49 |                         help='interval between saving the model (default: 50)')
50 |     parser.add_argument('--freeze-policy-iter', type=int, default=20,
51 |                         help='number of iterations to keep the pretrained policy frozen so that the value network can stabilize')
52 | parser.add_argument('--gpu-index', type=int, default=0, metavar='N')
53 | parser.add_argument('--rnd-epoch', type=int, default=4000, metavar='N')
54 | parser.add_argument('--running-state', type=int, default=1, metavar='N')
55 | parser.add_argument('--optim-epochs', type=int, default=10, metavar='N')
56 | parser.add_argument('--optim-batch-size', type=int, default=128, metavar='N')
57 |
58 | parser.add_argument('--load-running-state', type=int, default=0)
59 | parser.add_argument('--beta', type=float, default=0.005, help='beta VAE coefficient')
60 | parser.add_argument('--resume', action='store_true', default=False,
61 | help='resume pretrained models')
62 | parser.add_argument('--transfer', action='store_true', default=False,
63 |                         help='set this flag when the imitator and the expert are different')
64 |
65 | '''
66 | hyperparameters for wgan
67 | '''
68 | parser.add_argument('--gan-lr', type=float, default=3e-4)
69 | parser.add_argument('--beta1', type=float, default=0.5)
70 | parser.add_argument('--lam', type=float, default=1.)
71 | parser.add_argument('--gan-batch-size', type=int, default=256)
72 |
73 | parser.add_argument('--value-iter', type=int, default=2)
74 | parser.add_argument('--policy-iter', type=int, default=1)
75 | args = parser.parse_args()
76 | return args
77 | # load trajectory
78 |
79 | if __name__ == '__main__':
80 | args = get_args()
81 | dtype = torch.float32
82 | torch.set_default_dtype(dtype)
83 |
84 |
85 | expert_traj_raw = pickle.load(open(args.expert_traj_path, "rb")) # list of expert trajectories
86 | bc_agent = SoftBC_agent(args)
87 |
88 | if isinstance(expert_traj_raw, np.ndarray):
89 | expert_traj_raw_list = []
90 | for i in range(len(expert_traj_raw)):
91 | expert_traj_raw_list.append(expert_traj_raw[i])
92 | expert_traj_raw = expert_traj_raw_list
93 |
94 | expert_traj = process_expert_traj(expert_traj_raw)
95 | state_pairs = generate_pairs(expert_traj_raw, bc_agent.state_dim, size_per_traj=1000, max_step=1)
96 | state_tuples = generate_tuples(expert_traj_raw, bc_agent.state_dim)
97 |
98 | running_state = bc_agent.preprocess_running_state(expert_traj)
99 | expert_traj = normalize_expert_traj(running_state, expert_traj, bc_agent.state_dim)
100 | state_pairs = normalize_states(running_state, state_pairs, bc_agent.state_dim)
101 | state_tuples = normalize_states(running_state, state_tuples, bc_agent.state_dim)
102 |
103 |
104 | bc_agent.split_data(expert_traj, state_pairs, state_tuples)
105 |
106 | if args.resume is False:
107 | bc_agent.pretrain_vae()
108 | if args.transfer is True:
109 | bc_agent.pretrain_dynamics_with_l2()
110 | else:
111 | print('pretrain with demo')
112 | bc_agent.pretrain_dynamics_with_demo()
113 | bc_agent.save_model()
114 | bc_agent.pretrain_policy()
115 | bc_agent.train()
116 | else:
117 | bc_agent.load_model()
118 | bc_agent.pretrain_policy(epoches=250)
119 | bc_agent.save_model()
120 | bc_agent.train()
121 |
--------------------------------------------------------------------------------
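
sail.py expects expert_traj_raw to be a list (or array) of per-episode arrays whose rows are [state, action] concatenations, which is the layout save_expert_traj in agents/sac_agent.py pickles. A toy construction with made-up dimensions, just to show what the preprocessing helpers return:

    import numpy as np
    from utils.utils import process_expert_traj, generate_pairs, generate_tuples

    state_dim, action_dim = 11, 3                 # illustrative (Hopper-like) sizes
    # two fake episodes of length 100 and 80; each row is np.hstack([state, action])
    expert_traj_raw = [np.random.randn(100, state_dim + action_dim),
                       np.random.randn(80, state_dim + action_dim)]

    expert_traj = process_expert_traj(expert_traj_raw)   # (180, 14): all rows stacked
    pairs = generate_pairs(expert_traj_raw, state_dim, size_per_traj=1000, max_step=1)
    tuples = generate_tuples(expert_traj_raw, state_dim)
    print(expert_traj.shape, pairs.shape, tuples.shape)  # (180, 14) (2000, 22) (178, 25)
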
/envs/cycle_4room.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from gym import spaces
4 |
5 |
6 | class FourRoom(gym.Env):
7 | def __init__(self):
8 | self.n = 11
9 | self.map = np.array([
10 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
11 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
12 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
13 | 0, 1, 1, 4, 3, 3, 3, 3, 3, 1, 0,
14 | 0, 1, 1, 4, 1, 0, 1, 1, 2, 1, 0,
15 | 0, 0, 0, 4, 0, 0, 0, 0, 2, 0, 0,
16 | 0, 1, 1, 4, 1, 0, 0, 0, 2, 0, 0,
17 | 0, 1, 1, 4, 1, 0, 1, 1, 2, 1, 0,
18 | 0, 1, 1, 5, 5, 5, 5, 5, 2, 1, 0,
19 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
20 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
21 | ]).reshape((self.n, self.n))
22 | self.init()
23 |
24 | def init(self):
25 | self.observation_space = spaces.Box(low=0, high=1, shape=(self.n*self.n,), dtype=np.float32)
26 | self.observation_space.n = self.n
27 | self.dx = [0, 1, 0, -1]
28 | self.dy = [1, 0, -1, 0]
29 | self.default_action = np.array([3, 2, 1, 0])
30 | self.inverse_default_action = np.array([1, 0, 3, 2])
31 | self.action_space = spaces.Discrete(len(self.dx))
32 | self.reset()
33 |
34 | def label2obs(self, x, y):
35 | a = np.zeros((self.n*self.n,))
36 |         assert x < self.n and y < self.n
37 | a[x * self.n + y] = 1
38 | return a
39 |
40 | def get_obs(self):
41 | return self.label2obs(self.x, self.y)
42 |
43 | def reset(self):
44 | start = np.where(self.map == 2)
45 | assert len(start) == 2
46 | self.x, self.y = 5, 8
47 | self.done = False
48 | return self.get_obs()
49 |
50 | def set_xy(self, x, y):
51 | self.x = x
52 | self.y = y
53 | return self.get_obs()
54 |
55 | def compute_reward(self, prev_x, prev_y, action):
56 | info = {'is_success': False}
57 | done = False
58 | loc = self.map[prev_x, prev_y]
59 | assert loc > 0
60 | if loc < 2:
61 | reward = 0
62 | else:
63 | if action == self.default_action[loc-2]:
64 | reward = 1
65 | elif action == self.inverse_default_action[loc-2]:
66 | reward = -1
67 | else:
68 | reward = 0
69 | return reward, done, info
70 |
71 | def step(self, action):
72 | #assert not self.done
73 | nx, ny = self.x + self.dx[action], self.y + self.dy[action]
74 | info = {'is_success': False}
75 | #before = self.get_obs().argmax()
76 | if self.map[nx, ny]:
77 | reward, done, info = self.compute_reward(self.x, self.y, action)
78 | self.x, self.y = nx, ny
79 | else:
80 | #dis = (self.goal[0]-self.x)**2 + (self.goal[1]-self.y)**2
81 | #reward = -np.sqrt(dis)
82 | reward = 0
83 | done = False
84 | return self.get_obs(), reward, done, info
85 |
86 | def restore(self, obs):
87 | obs = obs.argmax()
88 | self.x = obs//self.n
89 | self.y = obs % self.n
90 |
91 | def inv_action(self, state, prev_state):
92 | x, y = state // self.n, state % self.n
93 | px, py = prev_state // self.n, prev_state % self.n
94 | dx = x - px
95 | dy = y - py
96 | if dx == 1 and dy == 0:
97 | return 1
98 | elif dx == -1 and dy == 0:
99 | return 3
100 | elif dy == 1 and dx == 0:
101 | return 0
102 | else:
103 | return 2
104 |
105 |
106 | class FourRoom1(FourRoom):
107 | def __init__(self, seed=None, *args, **kwargs):
108 | FourRoom.__init__(self, *args, **kwargs)
109 | self.n = 11
110 | self.map = np.array([
111 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
112 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
113 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
114 | 0, 1, 1, 4, 3, 3, 3, 3, 3, 1, 0,
115 | 0, 1, 1, 4, 1, 0, 1, 1, 2, 1, 0,
116 | 0, 0, 0, 4, 0, 0, 0, 0, 2, 0, 0,
117 | 0, 1, 1, 4, 1, 0, 0, 0, 2, 0, 0,
118 | 0, 1, 1, 4, 1, 0, 1, 1, 2, 1, 0,
119 | 0, 1, 1, 5, 5, 5, 5, 5, 2, 1, 0,
120 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
121 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
122 | ]).reshape((self.n, self.n))
123 | self.init()
124 |
125 | def init(self):
126 | self.observation_space = spaces.Box(low=0, high=1, shape=(self.n*self.n,), dtype=np.float32)
127 | self.observation_space.n = self.n
128 | self.dx = [0, 1, 0, -1, 0, 2, 0, -2]
129 | self.dy = [1, 0, -1, 0, 2, 0, -2, 0]
130 | self.default_action = np.array([3, 2, 1, 0, 7, 6, 5, 4])
131 | self.inverse_default_action = np.array([1, 0, 3, 2, 5, 4, 7, 6])
132 | self.action_space = spaces.Discrete(len(self.dx))
133 | self.reset()
134 |
135 | def compute_reward(self, prev_x, prev_y, action):
136 | info = {'is_success': False}
137 | done = False
138 | loc = self.map[prev_x, prev_y]
139 | assert loc > 0
140 | if loc < 2:
141 | reward = 0
142 | else:
143 | if action == self.default_action[loc-2]:
144 | reward = 1
145 | elif action == self.default_action[loc+2]:
146 | reward = 2
147 | elif action == self.inverse_default_action[loc-2]:
148 | reward = -1
149 | elif action == self.inverse_default_action[loc+2]:
150 | reward = -2
151 | else:
152 | reward = 0
153 | return reward, done, info
154 |
155 | def step(self, action):
156 | #assert not self.done
157 | nx, ny = max(0, self.x + self.dx[action]), max(0, self.y + self.dy[action])
158 | nx, ny = min(self.n-1, nx), min(self.n-1, ny)
159 | info = {'is_success': False}
160 | #before = self.get_obs().argmax()
161 | if self.map[nx, ny]:
162 | reward, done, info = self.compute_reward(self.x, self.y, action)
163 | self.x, self.y = nx, ny
164 | else:
165 | #dis = (self.goal[0]-self.x)**2 + (self.goal[1]-self.y)**2
166 | #reward = -np.sqrt(dis)
167 | reward = 0
168 | done = False
169 | return self.get_obs(), reward, done, info
170 |
171 |
--------------------------------------------------------------------------------
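
A short interaction sketch for the gridworld above (actions index into the dx/dy offsets defined in init; assumes the envs package imports cleanly):

    import numpy as np
    from envs.cycle_4room import FourRoom

    env = FourRoom()
    obs = env.reset()                  # one-hot vector of length n*n = 121
    assert obs.shape == (121,) and obs.sum() == 1

    for action in range(env.action_space.n):
        obs, reward, done, info = env.step(action)
        print(np.argmax(obs), reward, done)
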
/models/ppo_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from utils.math import *
4 | from torch.distributions import Normal
5 | from models.sac_models import weights_init_
6 | from torch.distributions.categorical import Categorical
7 |
8 | class Value(nn.Module):
9 | def __init__(self, state_dim, hidden_size=(256, 256), activation='tanh'):
10 | super().__init__()
11 | if activation == 'tanh':
12 | self.activation = torch.tanh
13 | elif activation == 'relu':
14 | self.activation = torch.relu
15 | elif activation == 'sigmoid':
16 | self.activation = torch.sigmoid
17 |
18 | self.affine_layers = nn.ModuleList()
19 | last_dim = state_dim
20 | for nh in hidden_size:
21 | self.affine_layers.append(nn.Linear(last_dim, nh))
22 | last_dim = nh
23 |
24 | self.value_head = nn.Linear(last_dim, 1)
25 | self.apply(weights_init_)
26 |
27 | def forward(self, x):
28 | for affine in self.affine_layers:
29 | x = self.activation(affine(x))
30 |
31 | value = self.value_head(x)
32 | return value
33 |
34 |
35 | class Policy(nn.Module):
36 | def __init__(self, state_dim, action_dim, hidden_size=(256, 256, 256), activation='tanh', log_std=0):
37 | super().__init__()
38 | self.is_disc_action = False
39 | if activation == 'tanh':
40 | self.activation = torch.tanh
41 | elif activation == 'relu':
42 | self.activation = torch.relu
43 | elif activation == 'sigmoid':
44 | self.activation = torch.sigmoid
45 |
46 | self.affine_layers = nn.ModuleList()
47 | last_dim = state_dim
48 | for nh in hidden_size:
49 | self.affine_layers.append(nn.Linear(last_dim, nh))
50 | last_dim = nh
51 |
52 | self.action_mean = nn.Linear(last_dim, action_dim)
53 | self.action_mean.weight.data.mul_(0.1)
54 | self.action_mean.bias.data.mul_(0.0)
55 |
56 | self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)
57 |
58 | def forward(self, x):
59 | for affine in self.affine_layers:
60 | x = self.activation(affine(x))
61 |
62 | action_mean = self.action_mean(x)
63 | action_log_std = self.action_log_std.expand_as(action_mean)
64 | action_std = torch.exp(action_log_std)
65 |
66 | return action_mean, action_log_std, action_std
67 |
68 | def select_action(self, x):
69 | action_mean, _, action_std = self.forward(x)
70 | normal = Normal(action_mean, action_std)
71 | action = normal.sample()
72 | return action
73 |
74 | def get_kl(self, x):
75 | mean1, log_std1, std1 = self.forward(x)
76 |
77 | mean0 = mean1.detach()
78 | log_std0 = log_std1.detach()
79 | std0 = std1.detach()
80 | kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
81 | return kl.sum(1, keepdim=True)
82 |
83 | def get_log_prob(self, x, actions):
84 | action_mean, action_log_std, action_std = self.forward(x)
85 | return normal_log_density(actions, action_mean, action_log_std, action_std)
86 |
87 | def get_log_prob_states(self, x):
88 | action_mean, action_log_std, action_std = self.forward(x)
89 | normal = Normal(action_mean, action_std)
90 | action = normal.rsample()
91 | entropy = normal.entropy().sum(-1)
92 | return action, entropy
93 |
94 | def get_fim(self, x):
95 | mean, _, _ = self.forward(x)
96 | cov_inv = self.action_log_std.exp().pow(-2).squeeze(0).repeat(x.size(0))
97 | param_count = 0
98 | std_index = 0
99 | id = 0
100 | for name, param in self.named_parameters():
101 | if name == "action_log_std":
102 | std_id = id
103 | std_index = param_count
104 | param_count += param.view(-1).shape[0]
105 | id += 1
106 | return cov_inv.detach(), mean, {'std_id': std_id, 'std_index': std_index}
107 |
108 | def get_entropy(self, x):
109 | mean, action_log_std, action_std = self.forward(x)
110 | dist = Normal(mean, action_std)
111 | entropy = dist.entropy().sum(-1).unsqueeze(-1)
112 | return entropy
113 |
114 |
115 | class DiscretePolicy(nn.Module):
116 | def __init__(self, state_dim, action_num, hidden_size=(128, 128), activation='tanh'):
117 | super().__init__()
118 | self.is_disc_action = True
119 | if activation == 'tanh':
120 | self.activation = torch.tanh
121 | elif activation == 'relu':
122 | self.activation = torch.relu
123 | elif activation == 'sigmoid':
124 | self.activation = torch.sigmoid
125 |
126 | self.affine_layers = nn.ModuleList()
127 | last_dim = state_dim
128 | for nh in hidden_size:
129 | self.affine_layers.append(nn.Linear(last_dim, nh))
130 | last_dim = nh
131 |
132 | self.action_head = nn.Linear(last_dim, action_num)
133 | self.action_head.weight.data.mul_(0.1)
134 | self.action_head.bias.data.mul_(0.0)
135 |
136 | def forward(self, x):
137 | for affine in self.affine_layers:
138 | x = self.activation(affine(x))
139 |
140 | action_prob = torch.softmax(self.action_head(x), dim=1)
141 | action = action_prob.multinomial(1)
142 | return action
143 |
144 | def log_prob(self, x):
145 | for affine in self.affine_layers:
146 | x = self.activation(affine(x))
147 | action_prob = torch.softmax(self.action_head(x), dim=1)
148 | return action_prob
149 |
150 | def select_action(self, x):
151 | action = self.forward(x)
152 | return action
153 |
154 | def get_kl(self, x):
155 | action_prob1 = self.log_prob(x)
156 | action_prob0 = action_prob1.detach()
157 | kl = action_prob0 * (torch.log(action_prob0) - torch.log(action_prob1))
158 | return kl.sum(1, keepdim=True)
159 |
160 | def get_log_prob(self, x, actions):
161 | action_prob = self.log_prob(x)
162 | return torch.log(action_prob.gather(1, actions.long().unsqueeze(1)))
163 |
164 | def get_log_prob_states(self, x):
165 | action_prob = self.log_prob(x)
166 | action = action_prob.multinomial(1)
167 |         return action, torch.log(action_prob.gather(1, action))
168 |
169 | def get_fim(self, x):
170 | action_prob = self.log_prob(x)
171 | M = action_prob.pow(-1).view(-1).detach()
172 | return M, action_prob, {}
173 |
174 | def get_entropy(self, x):
175 | action_prob = self.log_prob(x)
176 | dist = Categorical(action_prob)
177 | entropy = dist.entropy()
178 | return entropy
--------------------------------------------------------------------------------
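
A small sketch of how the continuous Policy and Value above are queried during rollout and when computing the fixed log-probabilities for a PPO update (dimensions are illustrative):

    import torch
    from models.ppo_models import Policy, Value

    state_dim, action_dim = 11, 3                 # illustrative sizes
    policy_net = Policy(state_dim, action_dim, log_std=-0.5)
    value_net = Value(state_dim)

    states = torch.randn(8, state_dim)
    with torch.no_grad():
        actions = policy_net.select_action(states)        # samples from N(mean, std)
        fixed_log_probs = policy_net.get_log_prob(states, actions)
        values = value_net(states)
    print(actions.shape, fixed_log_probs.shape, values.shape)   # (8, 3) (8, 1) (8, 1)
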
/expert/train_ppo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gym
3 | import os
4 | import sys
5 | import pickle
6 | import time
7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8 |
9 | from utils import *
10 | from models.ppo_models import Policy, Value, DiscretePolicy
11 | from core.ppo import ppo_step
12 | from core.common import estimate_advantages
13 | from core.agent import Agent
14 | import envs
15 | from tensorboardX import SummaryWriter
16 | import datetime
17 |
18 | parser = argparse.ArgumentParser(description='PyTorch PPO example')
19 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G',
20 | help='name of the environment to run')
21 | parser.add_argument('--model-path', metavar='G',
22 | help='path of pre-trained model')
23 | parser.add_argument('--resume', action='store_true', default=False)
24 | parser.add_argument('--render', action='store_true', default=False,
25 | help='render the environment')
26 | parser.add_argument('--log-std', type=float, default=-0.0, metavar='G',
27 | help='log std for the policy (default: -0.0)')
28 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
29 | help='discount factor (default: 0.99)')
30 | parser.add_argument('--tau', type=float, default=0.95, metavar='G',
31 | help='gae (default: 0.95)')
32 | parser.add_argument('--l2-reg', type=float, default=1e-3, metavar='G',
33 | help='l2 regularization regression (default: 1e-3)')
34 | parser.add_argument('--learning-rate', type=float, default=3e-4, metavar='G',
35 | help='learning rate (default: 3e-4)')
36 | parser.add_argument('--clip-epsilon', type=float, default=0.2, metavar='N',
37 | help='clipping epsilon for PPO')
38 | parser.add_argument('--seed', type=int, default=1, metavar='N',
39 | help='random seed (default: 1)')
40 | parser.add_argument('--min-batch-size', type=int, default=2048, metavar='N',
41 | help='minimal batch size per PPO update (default: 2048)')
42 | parser.add_argument('--epoch', type=int, default=1000, metavar='N',
43 |                     help='maximal number of main iterations (default: 1000)')
44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
45 |                     help='interval between training status logs (default: 1)')
46 | parser.add_argument('--save-model-interval', type=int, default=50, metavar='N',
47 |                     help='interval between saving the model (default: 50)')
48 | parser.add_argument('--gpu-index', type=int, default=0, metavar='N')
49 | args = parser.parse_args()
50 |
51 | dtype = torch.float32
52 | torch.set_default_dtype(dtype)
53 | device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu')
54 | if torch.cuda.is_available():
55 | torch.cuda.set_device(args.gpu_index)
56 |
57 | """environment"""
58 | env = gym.make(args.env_name)
59 | state_dim = env.observation_space.shape[0]
60 | is_disc_action = len(env.action_space.shape) == 0
61 | running_state = ZFilter((state_dim,), clip=5)
62 | save_path = '{}_ppo_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name)
63 | writer = SummaryWriter(log_dir='runs/'+save_path)
64 | # running_reward = ZFilter((1,), demean=False, clip=10)
65 |
66 | """seeding"""
67 | np.random.seed(args.seed)
68 | torch.manual_seed(args.seed)
69 | env.seed(args.seed)
70 | """define actor and critic"""
71 | if args.resume == False:
72 | if is_disc_action:
73 | policy_net = DiscretePolicy(state_dim, env.action_space.n)
74 | else:
75 | policy_net = Policy(state_dim, env.action_space.shape[0], log_std=args.log_std)
76 | value_net = Value(state_dim)
77 | else:
78 | policy_net, value_net = pickle.load(open(args.model_path, "rb"))
79 | policy_net.to(device)
80 | value_net.to(device)
81 |
82 | optimizer_policy = torch.optim.Adam(policy_net.parameters(), lr=args.learning_rate)
83 | optimizer_value = torch.optim.Adam(value_net.parameters(), lr=args.learning_rate)
84 |
85 | # optimization epoch number and batch size for PPO
86 | optim_epochs = 10
87 | optim_batch_size = 64
88 |
89 | """create agent"""
90 | agent = Agent(env, policy_net, device, render=args.render, running_state=running_state, update=True)
91 |
92 |
93 | def update_params(batch, i_iter):
94 | states = torch.from_numpy(np.stack(batch.state)).to(dtype).to(device)
95 | actions = torch.from_numpy(np.stack(batch.action)).to(dtype).to(device)
96 | rewards = torch.from_numpy(np.stack(batch.reward)).to(dtype).to(device)
97 | masks = torch.from_numpy(np.stack(batch.mask)).to(dtype).to(device)
98 | with torch.no_grad():
99 | values = value_net(states)
100 | fixed_log_probs = policy_net.get_log_prob(states, actions)
101 |
102 | """get advantage estimation from the trajectories"""
103 | advantages, returns = estimate_advantages(rewards, masks, values, args.gamma, args.tau, device)
104 |
105 | """perform mini-batch PPO update"""
106 | optim_iter_num = int(math.ceil(states.shape[0] / optim_batch_size))
107 | for _ in range(optim_epochs):
108 | perm = np.arange(states.shape[0])
109 | np.random.shuffle(perm)
110 | perm = LongTensor(perm).to(device)
111 |
112 | states, actions, returns, advantages, fixed_log_probs = \
113 | states[perm].clone(), actions[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone()
114 |
115 | for i in range(optim_iter_num):
116 | ind = slice(i * optim_batch_size, min((i + 1) * optim_batch_size, states.shape[0]))
117 | states_b, actions_b, advantages_b, returns_b, fixed_log_probs_b = \
118 | states[ind], actions[ind], advantages[ind], returns[ind], fixed_log_probs[ind]
119 |
120 | ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, 1, states_b, actions_b, returns_b,
121 | advantages_b, fixed_log_probs_b, args.clip_epsilon, args.l2_reg)
122 |
123 |
124 | def main_loop():
125 | total_numsteps = 0
126 | for i_iter in range(args.epoch):
127 | """generate multiple trajectories that reach the minimum batch_size"""
128 | batch, log = agent.collect_samples(args.min_batch_size)
129 | total_numsteps += log['num_steps']
130 | t0 = time.time()
131 | update_params(batch, i_iter)
132 | t1 = time.time()
133 | if i_iter % args.log_interval == 0:
134 | print('{}\tT_sample {:.4f}\tT_update {:.4f}\tR_min {:.2f}\tR_max {:.2f}\tR_avg {:.2f}'.format(
135 | total_numsteps, log['num_steps'], t1-t0, log['min_reward'], log['max_reward'], log['avg_reward']))
136 | writer.add_scalar('reward/env', log['avg_reward'], total_numsteps)
137 |
138 |
139 | if args.save_model_interval > 0 and (i_iter) % args.save_model_interval == 0:
140 | to_device(torch.device('cpu'), policy_net, value_net)
141 | pickle.dump((policy_net, value_net),
142 | open(os.path.join(assets_dir(), 'learned_models/{}_ppo_run.p'.format(args.env_name)), 'wb'))
143 | to_device(device, policy_net, value_net)
144 |
145 | """clean up gpu memory"""
146 | torch.cuda.empty_cache()
147 |
148 | main_loop()
149 |
--------------------------------------------------------------------------------
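
The actual PPO update is delegated to ppo_step in core/ppo.py, which is not shown in this listing. For reference, a generic clipped-surrogate policy loss of the kind such a step usually computes, written against the Policy interface above; this is a sketch, not necessarily identical to the repository's implementation:

    import torch

    def clipped_surrogate_loss(policy_net, states, actions, advantages,
                               fixed_log_probs, clip_epsilon=0.2):
        # ratio between the current policy and the policy that collected the batch
        log_probs = policy_net.get_log_prob(states, actions)
        ratio = torch.exp(log_probs - fixed_log_probs)
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
        return -torch.min(surr1, surr2).mean()
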
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def gen_dataset(expert_traj_raw, batch_size):
7 | length = len(expert_traj_raw)
8 | perm = np.arange(length)
9 | np.random.shuffle(perm)
10 | expert_traj = expert_traj_raw[perm].copy()
11 | while True:
12 | for i in range(len(expert_traj) // batch_size):
13 | yield expert_traj[i * batch_size:(i + 1) * batch_size]
14 |
15 | def create_log_gaussian(mean, log_std, t):
16 | quadratic = -((0.5 * (t - mean) / (log_std.exp())).pow(2))
17 | l = mean.shape
18 | log_z = log_std
19 | z = l[-1] * math.log(2 * math.pi)
20 | log_p = quadratic.sum(dim=-1) - log_z.sum(dim=-1) - 0.5 * z
21 | return log_p
22 |
23 |
24 | def logsumexp(inputs, dim=None, keepdim=False):
25 | if dim is None:
26 | inputs = inputs.view(-1)
27 | dim = 0
28 | s, _ = torch.max(inputs, dim=dim, keepdim=True)
29 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
30 | if not keepdim:
31 | outputs = outputs.squeeze(dim)
32 | return outputs
33 |
34 |
35 | def soft_update(target, source, tau):
36 | for target_param, param in zip(target.parameters(), source.parameters()):
37 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
38 |
39 |
40 | def hard_update(target, source):
41 | for target_param, param in zip(target.parameters(), source.parameters()):
42 | target_param.data.copy_(param.data)
43 |
44 | def fig2rgb_array(fig, expand=False):
45 | fig.canvas.draw()
46 | buf = fig.canvas.tostring_rgb()
47 | ncols, nrows = fig.canvas.get_width_height()
48 | shape = (nrows, ncols, 3) if not expand else (1, nrows, ncols, 3)
49 |     return np.frombuffer(buf, dtype=np.uint8).reshape(shape)
50 |
51 | from scipy import linalg
52 |
53 | def compute_precision_cholesky(covariances, covariance_type):
54 | """Compute the Cholesky decomposition of the precisions.
55 | Parameters
56 | ----------
57 | covariances : array-like
58 | The covariance matrix of the current components.
59 | The shape depends of the covariance_type.
60 | covariance_type : {'full', 'tied', 'diag', 'spherical'}
61 | The type of precision matrices.
62 | Returns
63 | -------
64 | precisions_cholesky : array-like
65 | The cholesky decomposition of sample precisions of the current
66 | components. The shape depends of the covariance_type.
67 | """
68 | estimate_precision_error_message = (
69 | "Fitting the mixture model failed because some components have "
70 | "ill-defined empirical covariance (for instance caused by singleton "
71 | "or collapsed samples). Try to decrease the number of components, "
72 | "or increase reg_covar.")
73 |
74 |     if covariance_type == 'full':
75 | n_components, n_features, _ = covariances.shape
76 | precisions_chol = np.empty((n_components, n_features, n_features))
77 | for k, covariance in enumerate(covariances):
78 | try:
79 | cov_chol = linalg.cholesky(covariance, lower=True)
80 | except linalg.LinAlgError:
81 | raise ValueError(estimate_precision_error_message)
82 | precisions_chol[k] = linalg.solve_triangular(cov_chol,
83 | np.eye(n_features),
84 | lower=True).T
85 | elif covariance_type == 'tied':
86 | _, n_features = covariances.shape
87 | try:
88 | cov_chol = linalg.cholesky(covariances, lower=True)
89 | except linalg.LinAlgError:
90 | raise ValueError(estimate_precision_error_message)
91 | precisions_chol = linalg.solve_triangular(cov_chol, np.eye(n_features),
92 | lower=True).T
93 | else:
94 | if np.any(np.less_equal(covariances, 0.0)):
95 | raise ValueError(estimate_precision_error_message)
96 | precisions_chol = 1. / np.sqrt(covariances)
97 | return precisions_chol
98 |
99 | def process_expert_traj(expert_traj_raw):
100 | expert_traj = []
101 | for i in range(len(expert_traj_raw)):
102 | for j in range(len(expert_traj_raw[i])):
103 | expert_traj.append(expert_traj_raw[i][j])
104 | expert_traj = np.stack(expert_traj)
105 | #print('here', expert_traj.shape)
106 | return expert_traj
107 |
108 | def generate_pairs(expert_traj_raw, state_dim, size_per_traj, max_step=6):
109 | '''
110 | generate state pairs (s, s_t)
111 | note that s_t can be multi-step future (controlled by max_step)
112 | '''
113 | pairs = []
114 | for i in range(len(expert_traj_raw)):
115 | traj = expert_traj_raw[i]
116 | if len(traj) == 0:
117 | continue
118 | start = np.random.randint(0, len(traj), size_per_traj)
119 | step = np.random.randint(1, max_step+1, size_per_traj)
120 | end = np.minimum(start+step, len(traj)-1)
121 | start_state, end_state = traj[start], traj[end]
122 |
123 | final_dim = state_dim*2
124 | state_pairs = np.concatenate([start_state[:, :state_dim], end_state[:, :state_dim]], axis=1)
125 |
126 | pairs.append(state_pairs)
127 | pairs = np.stack(pairs).reshape(-1, final_dim)
128 | return pairs
129 |
130 | def generate_tuples(expert_traj_raw, state_dim):
131 | '''
132 | generate transition tuples (s, s', a) for training
133 | '''
134 | expert_traj = []
135 | for i in range(len(expert_traj_raw)):
136 | for j in range(len(expert_traj_raw[i])):
137 | if j < len(expert_traj_raw[i])-1:
138 | state_action = expert_traj_raw[i][j]
139 | next_state = expert_traj_raw[i][j+1][:state_dim]
140 | transitions = np.concatenate([state_action[:state_dim], next_state, state_action[state_dim:]], axis=-1)
141 | expert_traj.append(transitions)
142 | expert_traj = np.stack(expert_traj)
143 | return expert_traj
144 |
145 | def normalize_expert_traj(running_state, expert_traj, state_dim):
146 | '''
147 | normalize the demonstration data by the state normalizer
148 | '''
149 | traj = []
150 | for i in range(len(expert_traj)):
151 | state = expert_traj[i, :state_dim]
152 | rest = expert_traj[i, state_dim:]
153 | state = running_state(state, update=False)
154 | tuple = np.concatenate([state, rest], axis=-1)
155 | traj.append(tuple)
156 | traj = np.stack(traj)
157 | return traj
158 |
159 | def normalize_states(running_state, state_pairs, state_dim):
160 | '''
161 | normalize the state pairs/tuples by state normalizer
162 | '''
163 | traj = []
164 | for i in range(len(state_pairs)):
165 | state = state_pairs[i, :state_dim]
166 | next_state = state_pairs[i, state_dim:state_dim*2]
167 | rest = state_pairs[i, state_dim*2:]
168 | state = running_state(state, update=False)
169 | next_state = running_state(next_state, update=False)
170 | tuple = np.concatenate([state, next_state, rest], axis=-1)
171 | traj.append(tuple)
172 | traj = np.stack(traj)
173 | return traj
174 |
175 |
176 | def adjust_lr(optimizer, scale):
177 | print('=========adjust learning rate================')
178 | for param_group in optimizer.param_groups:
179 | param_group['lr'] = param_group['lr'] / scale
--------------------------------------------------------------------------------
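
soft_update above implements the Polyak average target <- (1 - tau) * target + tau * source that keeps the SAC critic target slowly tracking the online critic. A tiny self-contained check:

    import torch
    import torch.nn as nn
    from utils.utils import soft_update, hard_update

    source, target = nn.Linear(4, 2), nn.Linear(4, 2)
    hard_update(target, source)                      # exact copy
    assert torch.equal(target.weight, source.weight)

    with torch.no_grad():
        source.weight.add_(1.0)                      # perturb the source network
    soft_update(target, source, tau=0.005)           # target moves 0.5% toward source
    print((target.weight - source.weight).abs().max())   # ~0.995 after one update
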
/agents/sac_agent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.optim import Adam
5 | from utils.utils import soft_update, hard_update
6 | from models.sac_models import GaussianPolicy, QNetwork, DeterministicPolicy
7 | import numpy as np
8 | from itertools import count
9 | import pickle
10 | from utils import *
11 |
12 | class SAC_agent(object):
13 | def __init__(self, env, num_inputs, action_space, args, running_state=None):
14 | self.gamma = args.gamma
15 | self.tau = args.tau
16 | self.alpha = args.alpha
17 | self.args = args
18 | self.env = env
19 | self.running_state = running_state
20 | self.policy_type = args.policy
21 | self.target_update_interval = args.target_update_interval
22 | self.automatic_entropy_tuning = args.automatic_entropy_tuning
23 |
24 | self.device = args.device
25 |
26 | self.critic = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(device=self.device)
27 | self.critic_optim = Adam(self.critic.parameters(), lr=args.lr)
28 |
29 | self.critic_target = QNetwork(num_inputs, action_space.shape[0], args.hidden_size).to(self.device)
30 | hard_update(self.critic_target, self.critic)
31 |
32 | if self.policy_type == "Gaussian":
33 | # Target Entropy = −dim(A) (e.g. , -6 for HalfCheetah-v2) as given in the paper
34 | if self.automatic_entropy_tuning == True:
35 | self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
36 | self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
37 | self.alpha_optim = Adam([self.log_alpha], lr=args.lr)
38 |
39 | self.policy = GaussianPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
40 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
41 |
42 | else:
43 | self.alpha = 0
44 | self.automatic_entropy_tuning = False
45 | self.policy = DeterministicPolicy(num_inputs, action_space.shape[0], args.hidden_size, action_space).to(self.device)
46 | self.policy_optim = Adam(self.policy.parameters(), lr=args.lr)
47 |
48 | def select_action(self, state, eval=False):
49 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
50 | if eval == False:
51 | action, _, _ = self.policy.sample(state)
52 | else:
53 | _, _, action = self.policy.sample(state)
54 | return action.detach().cpu().numpy()[0]
55 |
56 | def update_parameters(self, memory, batch_size, updates):
57 | # Sample a batch from memory
58 | state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
59 |
60 | state_batch = torch.FloatTensor(state_batch).to(self.device)
61 | next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
62 | action_batch = torch.FloatTensor(action_batch).to(self.device)
63 | reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
64 | mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)
65 |
66 | with torch.no_grad():
67 | next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
68 | qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
69 | min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
70 | next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
71 |
72 | qf1, qf2 = self.critic(state_batch, action_batch) # Two Q-functions to mitigate positive bias in the policy improvement step
73 | qf1_loss = F.mse_loss(qf1, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
74 | qf2_loss = F.mse_loss(qf2, next_q_value) # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - r(st,at) - γ(𝔼st+1~p[V(st+1)]))^2]
75 |
76 | pi, log_pi, _ = self.policy.sample(state_batch)
77 |
78 | qf1_pi, qf2_pi = self.critic(state_batch, pi)
79 | min_qf_pi = torch.min(qf1_pi, qf2_pi)
80 |
81 | policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean() # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
82 |
83 | self.critic_optim.zero_grad()
84 | qf1_loss.backward()
85 | self.critic_optim.step()
86 |
87 | self.critic_optim.zero_grad()
88 | qf2_loss.backward()
89 | self.critic_optim.step()
90 |
91 | self.policy_optim.zero_grad()
92 | policy_loss.backward()
93 | self.policy_optim.step()
94 |
95 | if self.automatic_entropy_tuning:
96 | alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
97 |
98 | self.alpha_optim.zero_grad()
99 | alpha_loss.backward()
100 | self.alpha_optim.step()
101 |
102 | self.alpha = self.log_alpha.exp()
103 | alpha_tlogs = self.alpha.clone() # For TensorboardX logs
104 | else:
105 | alpha_loss = torch.tensor(0.).to(self.device)
106 | alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
107 |
108 |
109 | if updates % self.target_update_interval == 0:
110 | soft_update(self.critic_target, self.critic, self.tau)
111 |
112 | return qf1_loss.item(), qf2_loss.item(), policy_loss.item(), alpha_loss.item(), alpha_tlogs.item()
113 |
114 | # Save model parameters
115 | def save_model(self, env_name, episode, actor_path=None, critic_path=None):
116 |
117 | if actor_path is None:
118 | actor_path = "assets/learned_models/sac_actor_{}_{}".format(env_name, str(episode))
119 | if critic_path is None:
120 | critic_path = "assets/learned_models/sac_critic_{}_{}".format(env_name, str(episode))
121 | print('Saving models to {} and {}'.format(actor_path, critic_path))
122 | torch.save(self.policy.state_dict(), actor_path)
123 | torch.save(self.critic.state_dict(), critic_path)
124 |
125 | # Load model parameters
126 | def load_model(self, actor_path, critic_path):
127 | print('Loading models from {} and {}'.format(actor_path, critic_path))
128 | if actor_path is not None:
129 | self.policy.load_state_dict(torch.load(actor_path))
130 | if critic_path is not None:
131 | self.critic.load_state_dict(torch.load(critic_path))
132 | self.critic_target.load_state_dict(torch.load(critic_path))
133 |
134 | def save_expert_traj(self, max_step=100000):
135 | print('save traj from a pretrained expert')
136 | num_steps = 0
137 | expert_trajs = []
138 | for i_episode in count():
139 | expert_traj = []
140 | state = self.env.reset()
141 | if self.running_state is not None:
142 | state = self.running_state(state)
143 | reward_episode = 0
144 |
145 | for t in range(10000):
146 | action = self.select_action(state, eval=True)
147 | next_state, reward, done, _ = self.env.step(action)
148 | if self.running_state is not None:
149 | next_state = self.running_state(next_state)
150 | reward_episode += reward
151 | num_steps += 1
152 | expert_traj.append(np.hstack([state, action]))
153 | if done:
154 | expert_traj = np.stack(expert_traj)
155 | expert_trajs.append(expert_traj)
156 | break
157 | state = next_state
158 | print('Episode {}\t reward: {:.2f}'.format(i_episode, reward_episode))
159 |
160 | if num_steps >= max_step:
161 | break
162 |
163 | reward_episode = int(reward_episode)
164 | if self.running_state is not None:
165 | pickle.dump((expert_trajs, self.running_state), open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(self.args.env_name)+'_'+str(reward_episode)), \
166 | 'wb'))
167 | else:
168 | pickle.dump(expert_trajs, open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(self.args.env_name)+'_'+str(reward_episode)), \
169 | 'wb'))
170 |
--------------------------------------------------------------------------------
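
A minimal sketch of constructing and querying the agent; the hyperparameter values in the SimpleNamespace are hypothetical, but the attribute names are exactly the ones SAC_agent.__init__ reads (running Hopper-v2 additionally requires MuJoCo):

    import gym
    from types import SimpleNamespace
    from agents.sac_agent import SAC_agent

    args = SimpleNamespace(gamma=0.99, tau=0.005, alpha=0.2, policy='Gaussian',
                           target_update_interval=1, automatic_entropy_tuning=True,
                           device='cpu', hidden_size=256, lr=3e-4, env_name='Hopper-v2')

    env = gym.make(args.env_name)
    agent = SAC_agent(env, env.observation_space.shape[0], env.action_space, args)
    exploratory = agent.select_action(env.reset())             # stochastic sample
    greedy = agent.select_action(env.reset(), eval=True)       # tanh(mean) action
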
/envs/mountaincar.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 |
5 | import gym
6 | from gym import spaces
7 | from gym.utils import seeding
8 |
9 | class Continuous_MountainCarEnv(gym.Env):
10 | metadata = {
11 | 'render.modes': ['human', 'rgb_array'],
12 | 'video.frames_per_second': 30
13 | }
14 |
15 | def __init__(self, goal_velocity = 0):
16 | self.min_action = -1.0
17 | self.max_action = 1.0
18 | self.min_position = -1.2
19 | self.max_position = 0.6
20 | self.max_speed = 0.07
21 | self.goal_position = 0.45 # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
22 | self.goal_velocity = goal_velocity
23 | self.power = 0.001
24 | self.grav = 0.002
25 | print('power, grav', self.power, self.grav)
26 | self.low_state = np.array([self.min_position, -self.max_speed])
27 | self.high_state = np.array([self.max_position, self.max_speed])
28 | '''
29 | modified
30 | '''
31 | self.pos_precision = 2
32 | self.vel_precision = 3
33 | self.pos_delta = 10 ** -self.pos_precision
34 | self.vel_delta = 10 ** -self.vel_precision
35 |         self.pos_range = np.round(
36 |             np.linspace(self.min_position, self.max_position,
37 |                         int(np.round((self.max_position - self.min_position) / self.pos_delta + 1))),
38 |             self.pos_precision)
39 |         self.vel_range = np.round(
40 |             np.linspace(-self.max_speed, self.max_speed, int(np.round((2 * self.max_speed) / self.vel_delta + 1))),
41 |             self.vel_precision)
42 |
43 |         self.test_pos_range = np.round(
44 |             np.linspace(-0.6, -0.4, int(np.round(0.2 / self.pos_delta + 1))),
45 |             self.pos_precision)
46 |         self.test_vel_range = np.round(
47 |             np.linspace(-0.03, 0.03, int(np.round(0.06 / self.vel_delta + 1))),
48 |             self.vel_precision)
49 |
50 | self.n_states = len(self.pos_range) * len(self.vel_range)
51 | self.n_test_states = len(self.test_pos_range) * len(self.test_vel_range)
52 |
53 | self.viewer = None
54 |
55 | self.action_space = spaces.Box(low=self.min_action, high=self.max_action,
56 | shape=(1,), dtype=np.float32)
57 | self.observation_space = spaces.Box(low=self.low_state, high=self.high_state,
58 | dtype=np.float32)
59 |
60 | #self.seed()
61 | self.reset()
62 |
63 | def seed(self, seed=None):
64 | self.np_random, seed = seeding.np_random(seed)
65 | return [seed]
66 |
67 | def get_id(self, state):
68 | if state is None:
69 | return self.n_states
70 | pos, vel = state
71 | if np.round(pos, self.pos_precision) > self.max_position:
72 | return self.n_states
73 |
74 | pos_id = np.where(self.pos_range == np.round(pos, self.pos_precision))[0][0]
75 | vel_id = np.where(self.vel_range == np.round(vel, self.vel_precision))[0][0]
76 | return pos_id * len(self.vel_range) + vel_id
77 |
78 | def visualize(self, dist):
79 | p = self.pos_range
80 | v = self.vel_range
81 | plot_V = np.zeros([len(p), len(v)])
82 |
83 | for i in range(len(p)):
84 | for j in range(len(v)):
85 | plot_V[i, j] = -dist[self.get_id((p[i], v[j]))]
86 |
87 | from matplotlib import pyplot as plt
88 | p = np.array(p)
89 | v = np.array(v)
90 | plot_V = plot_V.T
91 | extent = [np.amin(p), np.amax(p), np.amax(v), np.amin(v)]
92 | plt.imshow(plot_V, extent=extent, aspect='auto', cmap='hot', interpolation='nearest')
93 | plt.show()
94 |
95 | def step(self, action):
96 |
97 | position = self.state[0]
98 | velocity = self.state[1]
99 | force = min(max(action[0], -1.0), 1.0)
100 |
101 | velocity += force*self.power -self.grav * math.cos(3*position)
102 | if (velocity > self.max_speed): velocity = self.max_speed
103 | if (velocity < -self.max_speed): velocity = -self.max_speed
104 | position += velocity
105 | if (position > self.max_position): position = self.max_position
106 | if (position < self.min_position): position = self.min_position
107 | if (position==self.min_position and velocity<0): velocity = 0
108 |
109 | done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
110 |
111 | reward = 0
112 | if done:
113 | reward = 100.0
114 | reward-= math.pow(action[0],2)*0.1
115 |
116 | self.state = np.array([position, velocity])
117 | return self.state, reward, done, {}
118 |
119 | def vis_step(self, state, action):
120 |
121 | position = state[0]
122 | velocity = state[1]
123 | force = min(max(action[0], -1.0), 1.0)
124 |
125 | velocity += force*self.power -0.0025 * math.cos(3*position)
126 | if (velocity > self.max_speed): velocity = self.max_speed
127 | if (velocity < -self.max_speed): velocity = -self.max_speed
128 | position += velocity
129 | if (position > self.max_position): position = self.max_position
130 | if (position < self.min_position): position = self.min_position
131 | if (position==self.min_position and velocity<0): velocity = 0
132 |
133 | done = bool(position >= self.goal_position and velocity >= self.goal_velocity)
134 |
135 | reward = 0
136 | if done:
137 | reward = 100.0
138 | reward-= math.pow(action[0],2)*0.1
139 |
140 | next_state = np.array([position, velocity])
141 | return next_state, reward, done, {}
142 |
143 | def reset(self):
144 | self.state = np.array([-0.5, 0])
145 | return np.array(self.state)
146 |
147 | # def get_state(self):
148 | # return self.state
149 |
150 | def _height(self, xs):
151 | return np.sin(3 * xs)*.45+.55
152 |
153 | def render(self, mode='human'):
154 | screen_width = 600
155 | screen_height = 400
156 |
157 | world_width = self.max_position - self.min_position
158 | scale = screen_width/world_width
159 | carwidth=40
160 | carheight=20
161 |
162 |
163 | if self.viewer is None:
164 | from gym.envs.classic_control import rendering
165 | self.viewer = rendering.Viewer(screen_width, screen_height)
166 | xs = np.linspace(self.min_position, self.max_position, 100)
167 | ys = self._height(xs)
168 | xys = list(zip((xs-self.min_position)*scale, ys*scale))
169 |
170 | self.track = rendering.make_polyline(xys)
171 | self.track.set_linewidth(4)
172 | self.viewer.add_geom(self.track)
173 |
174 | clearance = 10
175 |
176 | l,r,t,b = -carwidth/2, carwidth/2, carheight, 0
177 | car = rendering.FilledPolygon([(l,b), (l,t), (r,t), (r,b)])
178 | car.add_attr(rendering.Transform(translation=(0, clearance)))
179 | self.cartrans = rendering.Transform()
180 | car.add_attr(self.cartrans)
181 | self.viewer.add_geom(car)
182 | frontwheel = rendering.make_circle(carheight/2.5)
183 | frontwheel.set_color(.5, .5, .5)
184 | frontwheel.add_attr(rendering.Transform(translation=(carwidth/4,clearance)))
185 | frontwheel.add_attr(self.cartrans)
186 | self.viewer.add_geom(frontwheel)
187 | backwheel = rendering.make_circle(carheight/2.5)
188 | backwheel.add_attr(rendering.Transform(translation=(-carwidth/4,clearance)))
189 | backwheel.add_attr(self.cartrans)
190 | backwheel.set_color(.5, .5, .5)
191 | self.viewer.add_geom(backwheel)
192 | flagx = (self.goal_position-self.min_position)*scale
193 | flagy1 = self._height(self.goal_position)*scale
194 | flagy2 = flagy1 + 50
195 | flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
196 | self.viewer.add_geom(flagpole)
197 | flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2-10), (flagx+25, flagy2-5)])
198 | flag.set_color(.8,.8,0)
199 | self.viewer.add_geom(flag)
200 |
201 | pos = self.state[0]
202 | self.cartrans.set_translation((pos-self.min_position)*scale, self._height(pos)*scale)
203 | self.cartrans.set_rotation(math.cos(3 * pos))
204 |
205 | return self.viewer.render(return_rgb_array = mode=='rgb_array')
206 |
207 | def close(self):
208 | if self.viewer:
209 | self.viewer.close()
210 | self.viewer = None
--------------------------------------------------------------------------------
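
With pos_precision=2 and vel_precision=3 the constructor discretizes positions onto a 0.01 grid over [-1.2, 0.6] (181 values) and velocities onto a 0.001 grid over [-0.07, 0.07] (141 values), so n_states = 181 * 141 = 25521. A quick check of the id mapping (assumes the envs package imports cleanly):

    from envs.mountaincar import Continuous_MountainCarEnv

    env = Continuous_MountainCarEnv()
    print(env.n_states)          # 25521 = 181 positions * 141 velocities
    state = env.reset()          # deterministic start at [-0.5, 0.0]
    print(env.get_id(state))     # grid index of the start state
    print(env.get_id(None))      # out-of-grid sentinel: n_states
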
/models/dynamics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from utils.tools import swish
5 | from torch.nn import functional as F
6 | import tqdm
7 | import math
8 |
9 | def weights_init_(m):
10 | if isinstance(m, nn.Linear):
11 | stdv = 1. / math.sqrt(m.weight.size(1))
12 |         torch.nn.init.normal_(m.weight, std=0.5 * stdv)
13 | torch.nn.init.constant_(m.bias, 0)
14 |
15 | class ProbModel(nn.Module):
16 | def __init__(self, in_features, out_features, device, hidden_dim=400, fix_sigma=True, use_diag=True):
17 | super().__init__()
18 | self.device = device
19 | self.in_features = in_features
20 | self.out_features = out_features
21 | #define layers of shared feature, mean and variance
22 | self.layers = 3
23 | self.mean_layers = 1
24 | self.var_layers = 1
25 |
26 | self.affine = nn.ModuleList()
27 | self.mean = nn.ModuleList()
28 |
29 | self.use_diag = use_diag
30 | if use_diag == True:
31 | self.scale = nn.Parameter(torch.Tensor((0.5,)), requires_grad=True)
32 | else:
33 | self.scale = nn.Linear(hidden_dim, out_features)
34 |
35 | for i in range(self.layers):
36 | if i == 0:
37 | self.affine.append(nn.Linear(in_features,hidden_dim))
38 | else:
39 | self.affine.append(nn.Linear(hidden_dim,hidden_dim))
40 |
41 | for i in range(self.mean_layers):
42 | if i == self.mean_layers - 1:
43 | self.mean.append(nn.Linear(hidden_dim, out_features))
44 | else:
45 | self.mean.append(nn.Linear(hidden_dim, hidden_dim))
46 |
47 | '''
48 | for i in range(self.var_layers):
49 | if i == self.var_layers - 1:
50 | self.logvar.append(nn.Linear(hidden_dim, out_features))
51 | else:
52 | self.logvar.append(nn.Linear(hidden_dim, hidden_dim))
53 | '''
54 |
55 | self.apply(weights_init_)
56 | self.fix_sigma = fix_sigma
57 | #self.logvar = torch.log(nn.Parameter(torch.ones(1, out_features, dtype=torch.float32) * 0.1, requires_grad=True))
58 | if self.fix_sigma is not True:
59 | self.max_logvar = nn.Parameter(torch.ones(1, out_features, dtype=torch.float32) / 2.0)
60 | self.min_logvar = nn.Parameter(-torch.ones(1, out_features, dtype=torch.float32) * 10.0)
61 |
62 | def forward(self, inputs, ret_logvar=False):
63 | for affine in self.affine:
64 | inputs = swish(affine(inputs))
65 |
66 | mean = inputs
67 | for i, mean_layer in enumerate(self.mean):
68 | if i == len(self.mean) - 1:
69 | mean = mean_layer(mean)
70 | else:
71 | mean = swish(mean_layer(mean))
72 | if self.use_diag == True:
73 | logvar = torch.log((self.scale.expand(inputs.shape[0], self.out_features))).to(self.device)
74 | else:
75 | logvar = self.scale(inputs)
76 | #logvar = logvar.expand(inputs.shape[0], self.out_features).to(self.device)
77 |
78 | '''
79 | logvar = inputs
80 | for i, var_layer in enumerate(self.logvar):
81 | if i == len(self.logvar) - 1:
82 | logvar = var_layer(logvar)
83 | else:
84 | logvar = swish(var_layer(logvar))
85 | '''
86 |
87 | if self.fix_sigma == False:
88 | logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
89 | logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
90 |
91 | if ret_logvar:
92 | return mean, logvar
93 | return mean, torch.exp(logvar)
94 |
95 | def set_sigma(self, sigma):
96 | self.fix_sigma = sigma
97 |
98 | class ForwardModel(nn.Module):
99 | def __init__(self, in_features, out_features, hidden_dim=128):
100 | super().__init__()
101 | self.in_features = in_features
102 | self.out_features = out_features
103 | self.affine_layers = nn.ModuleList()
104 | #self.bn_layers = nn.ModuleList()
105 | self.layers = 6
106 | self.first_layer = nn.Linear(self.in_features, hidden_dim)
107 | for i in range(self.layers):
108 | self.affine_layers.append(nn.Linear(hidden_dim, hidden_dim))
109 | self.relu = nn.ReLU()
110 |
111 | self.fc = nn.Linear(hidden_dim, out_features)
112 | self.apply(weights_init_)
113 |
114 |
115 | def forward(self, state, action):
116 | inputs = torch.cat((state, action), -1)
117 | last_output = self.relu(self.first_layer(inputs))
118 | for i, affine in enumerate(self.affine_layers):
119 | res = self.relu(affine(last_output))
120 | output = self.relu(last_output+res)
121 | last_output = output
122 | delta = self.fc(last_output)
123 | return delta
124 |
125 | def get_next_states(self, state, action):
126 | delta = self.forward(state, action)
127 | return state + delta
128 |
129 | def train(self, inputs_state, inputs_action, targets, optimizer, epoch=30, batch_size=256):
130 | #print('training model')
131 | idxs = np.arange(inputs_state.shape[0])
132 | np.random.shuffle(idxs)
133 | from tqdm import trange
134 | #epoch_range = trange(epoch, unit="epoch(s)", desc="Network training")
135 | num_batch = int(np.ceil(idxs.shape[-1] / batch_size))
136 |
137 | for _ in range(epoch):
138 | idxs = np.arange(inputs_state.shape[0])
139 | np.random.shuffle(idxs)
140 | for batch_num in range(num_batch):
141 | batch_idxs = idxs[batch_num * batch_size : (batch_num + 1) * batch_size]
142 | train_in_states = inputs_state[batch_idxs].float()
143 | train_in_actions = inputs_action[batch_idxs].float()
144 | train_targ = targets[batch_idxs].float()
145 |
146 | mean = self.forward(train_in_states, train_in_actions)
147 | train_losses = ((mean - train_targ) ** 2).mean()
148 | optimizer.zero_grad()
149 | train_losses.backward()
150 | optimizer.step()
151 |
152 | mean = self.forward(inputs_state, inputs_action)
153 | mse_losses = ((mean - targets) ** 2).mean(-1).mean(-1)
154 | print('forward model mse loss', mse_losses.detach().cpu().numpy())
155 | return mse_losses.detach().cpu().numpy()
156 |
157 | class InverseModel(nn.Module):
158 | def __init__(self, in_features, out_features, hidden_dim=128):
159 | super().__init__()
160 | self.in_features = in_features
161 | self.out_features = out_features
162 | self.affine_layers = nn.ModuleList()
163 | #self.bn_layers = nn.ModuleList()
164 | self.layers = 6
165 | self.first_layer = nn.Linear(self.in_features, hidden_dim)
166 | for i in range(self.layers):
167 | self.affine_layers.append(nn.Linear(hidden_dim, hidden_dim))
168 | self.relu = nn.ReLU()
169 |
170 | self.final = nn.Linear(hidden_dim, out_features)
171 | self.apply(weights_init_)
172 |
173 |
174 | def forward(self, state, next_state):
175 | inputs = torch.cat((state, next_state), -1)
176 | last_output = self.relu(self.first_layer(inputs))
177 | for i, affine in enumerate(self.affine_layers):
178 | res = self.relu(affine(last_output))
179 | output = self.relu(last_output+res)
180 | last_output = output
181 | action = self.final(last_output)
182 | return action
183 |
184 | def train(self, state, next_state, actions, optimizer, epoch=30, batch_size=256):
185 | idxs = np.arange(state.shape[0])
186 | np.random.shuffle(idxs)
187 | num_batch = int(np.ceil(idxs.shape[-1] / batch_size))
188 |
189 | for _ in range(epoch):
190 | idxs = np.arange(state.shape[0])
191 | np.random.shuffle(idxs)
192 | for batch_num in range(num_batch):
193 | batch_idxs = idxs[batch_num * batch_size : (batch_num + 1) * batch_size]
194 | states_train = state[batch_idxs].float()
195 | next_state_train = next_state[batch_idxs].float()
196 | actions_targ = actions[batch_idxs].float()
197 |
198 | res = self.forward(states_train, next_state_train)
199 | train_losses = ((res - actions_targ) ** 2).mean()
200 | optimizer.zero_grad()
201 | train_losses.backward()
202 | optimizer.step()
203 |
204 | actions_pred = self.forward(state, next_state)
205 | mse_losses = ((actions_pred - actions) ** 2).mean(-1).mean(-1)
206 | print('inverse model mse loss', mse_losses.detach().cpu().numpy())
207 | return mse_losses.detach().cpu().numpy()
--------------------------------------------------------------------------------
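The residual MLPs above regress state deltas (ForwardModel) and actions (InverseModel) with plain MSE training loops. Below is a minimal sketch, not taken from the repository, of fitting a ForwardModel on synthetic transitions; the dimensions, learning rate and random data are assumptions for illustration only.

import torch
from models.dynamics import ForwardModel

state_dim, action_dim = 11, 3                         # assumed sizes, for illustration
model = ForwardModel(state_dim + action_dim, state_dim, hidden_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic transitions; in SAIL these would come from collected rollouts.
states = torch.randn(1024, state_dim)
actions = torch.randn(1024, action_dim)
next_states = states + 0.01 * torch.randn(1024, state_dim)

# ForwardModel.train regresses the state delta, so the target is next_states - states.
model.train(states, actions, next_states - states, optimizer, epoch=5, batch_size=256)
pred_next = model.get_next_states(states, actions)    # states + predicted delta

ProbModel uses the same swish-activated trunk but also returns a variance term, either a single learned scale (use_diag=True) or a state-dependent linear head, which is why its forward can return either the log-variance or its exponential.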
/envs/fourroom.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | from gym import spaces
4 |
5 |
6 | class FourRoom(gym.Env):
7 | def __init__(self, seed=None, goal_type='fix_goal'):
8 | self.n = 11
9 | self.map = np.array([
10 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
11 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
12 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
13 | 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
14 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
15 | 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
16 | 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,
17 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
18 | 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
19 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
20 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
21 | ]).reshape((self.n, self.n))
22 | self.goal_type = goal_type
23 | self.goal = None
24 | self.init()
25 |
26 | def init(self):
27 | self.observation_space = spaces.Box(low=0, high=1, shape=(self.n*self.n,), dtype=np.float32)
28 | self.observation_space.n = self.n
29 | self.dx = [0, 1, 0, -1]
30 | self.dy = [1, 0, -1, 0]
31 | self.action_space = spaces.Discrete(len(self.dx))
32 | self.reset()
33 |
34 | def label2obs(self, x, y):
35 | a = np.zeros((self.n*self.n,))
36 |         assert x < self.n and y < self.n
37 | a[x * self.n + y] = 1
38 | return a
39 |
40 | def get_obs(self):
41 | assert self.goal is not None
42 | return self.label2obs(self.x, self.y)
43 |
44 | def reset(self):
45 | '''
46 | condition = True
47 | while condition:
48 | self.x = np.random.randint(1, self.n)
49 | self.y = np.random.randint(1, self.n)
50 | condition = (self.map[self.x, self.y] == 0)
51 | '''
52 | self.x, self.y = 9, 9
53 | loc = np.where(self.map > 0.5)
54 | assert len(loc) == 2
55 | #if self.goal_type == 'random':
56 | # goal_idx = np.random.randint(len(loc[0]))
57 | if self.goal_type == 'fix_goal':
58 | goal_idx = 0
59 | else:
60 | raise NotImplementedError
61 | self.goal = loc[0][goal_idx], loc[1][goal_idx]
62 | self.done = False
63 | return self.get_obs()
64 |
65 | def set_xy(self, x, y):
66 | self.x = x
67 | self.y = y
68 | return self.get_obs()
69 |
70 | def step(self, action):
71 | #assert not self.done
72 | nx, ny = self.x + self.dx[action], self.y + self.dy[action]
73 | info = {'is_success': False}
74 | #before = self.get_obs().argmax()
75 | if self.map[nx, ny]:
76 | self.x, self.y = nx, ny
77 | #dis = (self.goal[0]-self.x)**2 + (self.goal[1]-self.y)**2
78 | #reward = -np.sqrt(dis)
79 | reward = -1
80 | done = False
81 | else:
82 | #dis = (self.goal[0]-self.x)**2 + (self.goal[1]-self.y)**2
83 | #reward = -np.sqrt(dis)
84 | reward = -1
85 | done = False
86 | if nx == self.goal[0] and ny == self.goal[1]:
87 | reward = 0
88 | info = {'is_success': True}
89 | done = self.done = True
90 | return self.get_obs(), reward, done, info
91 |
92 | def compute_reward(self, state, goal, info):
93 | state_obs = state.argmax(axis=1)
94 | goal_obs = goal.argmax(axis=1)
95 | reward = np.where(state_obs == goal_obs, 0, -1)
96 | return reward
97 |
98 | def restore(self, obs):
99 | obs = obs.argmax()
100 | self.x = obs//self.n
101 | self.y = obs % self.n
102 |
103 | def inv_action(self, state, prev_state):
104 | x, y = state // self.n, state % self.n
105 | px, py = prev_state // self.n, prev_state % self.n
106 | dx = x - px
107 | dy = y - py
108 | if dx == 1 and dy == 0:
109 | return 1
110 | elif dx == -1 and dy == 0:
111 | return 3
112 | elif dy == 1 and dx == 0:
113 | return 0
114 | else:
115 | return 2
116 |
117 |
118 | def bfs_dist(self, state, goal, order=True):
119 | #using bfs to search for shortest path
120 | visited = {key: False for key in range(self.n*self.n)}
121 | state_key = state.argmax()
122 | goal_key = goal.argmax()
123 | queue = []
124 | visited[state_key] = True
125 | queue.append(state_key)
126 | dist = [-np.inf] * (self.n*self.n)
127 | past = [-1] * (self.n*self.n)
128 | dist[state_key] = 0
129 |
130 | if order:
131 | act_order = range(4)
132 | else:
133 |             act_order = range(3, -1, -1)  # reversed order over all four actions
134 |
135 | while (queue):
136 | par = queue.pop(0)
137 | if par == goal_key:
138 | break
139 | x_par, y_par = par // self.n, par % self.n
140 | for action in act_order:
141 | x_child, y_child = x_par + self.dx[action], y_par + self.dy[action]
142 | child = x_child*self.n + y_child
143 | if self.map[x_child, y_child] == 0:
144 | continue
145 | if visited[child] == False:
146 | visited[child] = True
147 | queue.append(child)
148 | dist[child] = dist[par] + 1
149 | past[child] = par
150 |
151 | state_action_pair = []
152 | cur_state = goal_key
153 |         while cur_state != state_key:
154 | prev_state = past[cur_state]
155 | prev_action = self.inv_action(cur_state, prev_state)
156 | x_prev, y_prev = prev_state // self.n, prev_state % self.n
157 | print(x_prev, y_prev)
158 | state_action_pair.append(np.hstack([self.label2obs(x_prev, y_prev), np.array((prev_action, ))]))
159 | cur_state = prev_state
160 | state_action_pair.reverse()
161 | state_action_pair.append(np.hstack([self.label2obs(goal_key // self.n, goal_key % self.n), np.array((prev_action, ))]))
162 | print(len(state_action_pair))
163 | return dist, state_action_pair
164 |
165 | def get_pairwise(self, state, target):
166 | dist = self.bfs_dist(state, target)
167 | return dist
168 |
169 | def all_states(self):
170 | states = []
171 | mask = []
172 | for i in range(self.n):
173 | for j in range(self.n):
174 | self.x = i
175 | self.y = j
176 | states.append(self.get_obs())
177 | if isinstance(states[-1], dict):
178 | states[-1] = states[-1]['observation']
179 | mask.append(self.map[self.x, self.y] > 0.5)
180 | return np.array(states)[mask]
181 |
182 | def all_edges(self):
183 | A = np.zeros((self.n*self.n, self.n*self.n))
184 | mask = []
185 | for i in range(self.n):
186 | for j in range(self.n):
187 | mask.append(self.map[i, j] > 0.5)
188 | if self.map[i][j]:
189 | for a in range(4):
190 | self.x = i
191 | self.y = j
192 | t = self.step(a)[0]
193 | if isinstance(t, dict):
194 | t = t['observation']
195 | self.restore(t)
196 | A[i*self.n+j, self.x*self.n + self.y] = 1
197 | return A[mask][:, mask]
198 |
199 | def add_noise(self, start, goal, dist, alpha=0.1, order=False):
200 |
201 | if order:
202 | act_order = range(4)
203 | else:
204 |             act_order = range(3, -1, -1)  # reversed order over all four actions
205 |
206 | cur_state_id = start.argmax()
207 | goal_id = goal.argmax()
208 | new_seq = []
209 | while(cur_state_id != goal_id):
210 | x_cur, y_cur = cur_state_id // self.n, cur_state_id % self.n
211 |             if np.random.rand() < alpha:  # take a random action with probability alpha
212 | cur_action = np.random.randint(4)
213 | nx, ny = x_cur+self.dx[cur_action], y_cur+self.dy[cur_action]
214 | new_seq.append(np.hstack([self.label2obs(x_cur, y_cur), np.array((cur_action, ))]))
215 | #print('state, action', (cur_state_id//self.n, cur_state_id%self.n), cur_action)
216 | if self.map[nx][ny] > 0.5:
217 | cur_state_id = nx*self.n + ny
218 | else:
219 | cur_state_id = cur_state_id
220 | else:
221 | dist_n = -np.inf
222 | cur_action = -1
223 | for action in act_order:
224 | x_n, y_n = x_cur + self.dx[action], y_cur + self.dy[action]
225 | if dist[x_n*self.n+y_n] > dist_n:
226 | dist_n = dist[x_n*self.n+y_n]
227 | cur_action = action
228 | elif dist[x_n*self.n+y_n] == dist_n:
229 | cur_action = np.random.choice(np.array([cur_action, action]))
230 |
231 | nx, ny = x_cur+self.dx[cur_action], y_cur+self.dy[cur_action]
232 | new_seq.append(np.hstack([self.label2obs(x_cur, y_cur), np.array((cur_action, ))]))
233 | #print('state, action', (cur_state_id//self.n, cur_state_id%self.n), cur_action)
234 | if self.map[nx][ny] > 0.5:
235 | cur_state_id = nx*self.n + ny
236 | else:
237 | cur_state_id = cur_state_id
238 |
239 | new_seq.append(np.hstack([self.label2obs(goal_id//self.n, goal_id%self.n), np.array((cur_action, ))]))
240 | return new_seq
241 |
242 | class FourRoom1(FourRoom):
243 | def __init__(self, seed=None, *args, **kwargs):
244 | FourRoom.__init__(self, *args, **kwargs)
245 | self.n = 11
246 | self.map = np.array([
247 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
248 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
249 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
250 | 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
251 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
252 | 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0,
253 | 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,
254 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
255 | 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
256 | 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
257 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
258 | ]).reshape((self.n, self.n))
259 | self.init()
260 |
261 | def init(self):
262 | self.observation_space = spaces.Box(low=0, high=1, shape=(self.n*self.n,), dtype=np.float32)
263 | self.observation_space.n = self.n
264 | self.dx = [0, 1, 0, -1]
265 | self.dy = [2, 0, -2, 0]
266 | self.action_space = spaces.Discrete(len(self.dx))
267 | self.reset()
268 |
269 | def inv_action(self, state, prev_state):
270 | x, y = state // self.n, state % self.n
271 | px, py = prev_state // self.n, prev_state % self.n
272 | dx = x - px
273 | dy = y - py
274 | if dx == 1 and dy == 0:
275 | return 1
276 | elif dx == -1 and dy == 0:
277 | return 3
278 | elif dy == 2 and dx == 0:
279 | return 0
280 | else:
281 | return 2
282 |
283 | def step(self, action):
284 | #assert not self.done
285 | nx, ny = max(0, self.x + self.dx[action]), max(0, self.y + self.dy[action])
286 | nx, ny = min(self.n-1, nx), min(self.n-1, ny)
287 | info = {'is_success': False}
288 | #before = self.get_obs().argmax()
289 | if self.map[nx, ny]:
290 | self.x, self.y = nx, ny
291 | #dis = (self.goal[0]-self.x)**2 + (self.goal[1]-self.y)**2
292 | #reward = -np.sqrt(dis)
293 | reward = -1
294 | done = False
295 | else:
296 | #dis = (self.goal[0]-self.x)**2 + (self.goal[1]-self.y)**2
297 | #reward = -np.sqrt(dis)
298 | reward = -1
299 | done = False
300 | if nx == self.goal[0] and ny == self.goal[1]:
301 | reward = 0
302 | info = {'is_success': True}
303 | done = self.done = True
304 | return self.get_obs(), reward, done, info
305 |
--------------------------------------------------------------------------------
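The FourRoom environments above generate their own demonstrations: bfs_dist returns BFS shortest-path distances together with the corresponding (one-hot state, action) pairs, and add_noise can perturb such a path with random actions. A minimal usage sketch follows; the variable names and the final print are chosen purely for illustration and add_noise is not exercised here.

from envs.fourroom import FourRoom

env = FourRoom(goal_type='fix_goal')
start = env.reset()                      # one-hot encoding of the start cell (9, 9)
goal = env.label2obs(*env.goal)          # one-hot encoding of the fixed goal cell

# Shortest path from start to goal as (one-hot state, action) rows.
dist, demo = env.bfs_dist(start, goal)
actions = [int(pair[-1]) for pair in demo]

obs, reward, done, info = env.step(actions[0])
print(len(demo), reward, done)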
/agents/soft_bc_agent.py:
--------------------------------------------------------------------------------
1 | from models.ppo_models import Value, Policy, DiscretePolicy
2 | import pickle
3 | import numpy as np
4 | from utils import *
5 | from utils.utils import adjust_lr
6 | from core.common import estimate_advantages
7 | from core.agent import Agent
8 | import gym
9 | from models.WGAN import W_Discriminator
10 | from models.dynamics import ForwardModel, InverseModel
11 | from tensorboardX import SummaryWriter
12 | import datetime
13 | import torch.autograd as autograd
14 | from torch.distributions import Normal
15 | from utils.utils import soft_update, hard_update
16 | from models.VAE import VAE
17 |
18 |
19 | class SoftBC_agent(object):
20 |
21 | def __init__(self, args, running_state=None):
22 | """environment"""
23 | self.env = gym.make(args.env_name)
24 | self.dtype = torch.float32
25 | self.args = args
26 | torch.set_default_dtype(self.dtype)
27 | self.state_dim = self.env.observation_space.shape[0]
28 | self.is_disc_action = len(self.env.action_space.shape) == 0
29 | self.action_dim = 1 if self.is_disc_action else self.env.action_space.shape[0]
30 | self.device = torch.device('cuda', index=args.gpu_index) if torch.cuda.is_available() else torch.device('cpu')
31 |
32 | '''running normalizer'''
33 | if running_state is not None:
34 | self.running_state = running_state
35 | else:
36 | self.running_state = ZFilter((self.state_dim,), clip=5)
37 |
38 | """seeding"""
39 | np.random.seed(args.seed)
40 | torch.manual_seed(args.seed)
41 | self.env.seed(args.seed)
42 |
43 | """define actor and critic"""
44 | if self.is_disc_action:
45 | self.policy_net = DiscretePolicy(self.state_dim, self.env.action_space.n)
46 | else:
47 | self.policy_net = Policy(self.state_dim, self.env.action_space.shape[0], log_std=args.log_std)
48 | self.value_net = Value(self.state_dim)
49 | self.goal_model = VAE(self.state_dim, latent_dim=128)
50 | self.inverse_model = InverseModel(self.state_dim*2, self.action_dim)
51 | self.discrim_net = W_Discriminator(self.state_dim, hidden_size=256)
52 | self.max_action = self.env.action_space.high[0]
53 |
54 | self.value_iter = self.args.value_iter
55 | self.policy_iter = 0
56 |
57 | to_device(self.device, self.policy_net, self.value_net, self.goal_model, \
58 | self.inverse_model, self.discrim_net)
59 |
60 | self.optimizer_policy = torch.optim.Adam(self.policy_net.parameters(), lr=args.policy_lr)
61 | self.optimizer_value = torch.optim.Adam(self.value_net.parameters(), lr=args.value_lr)
62 | self.optimizer_vae = torch.optim.Adam(self.goal_model.parameters(), lr=args.model_lr)
63 | self.optimizer_discrim = torch.optim.Adam(self.discrim_net.parameters(), lr=args.gan_lr, betas=(args.beta1, 0.99))
64 | self.optimizer_inverse = torch.optim.Adam(self.inverse_model.parameters(), lr=args.model_lr)
65 |
66 | self.save_path = '{}_softbc_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name, \
67 | args.beta)
68 |         self.writer = SummaryWriter(log_dir='runs/' + self.save_path)
69 |
70 | self.trained = False
71 | self.trajectory = OnlineMemory()
72 |
73 | self.agent = Agent(self.env, self.policy_net, self.device, mean_action=True, custom_reward=self.expert_reward,
74 | running_state=self.running_state, render=args.render, update=False)
75 |
76 | def expert_reward(self, state, action, next_state, done):
77 |         coeff = 0.1  # scale the reward to stabilize training
78 | state = torch.tensor(state, dtype=self.dtype).to(self.device)
79 | next_state = torch.tensor(next_state, dtype=self.dtype).to(self.device)
80 | if self.trained is False:
81 | return 0
82 | with torch.no_grad():
83 | reward = self.discrim_net(next_state[None, :]).item() - self.get_expert_mean()
84 | return reward*coeff
85 |
86 | def preprocess_running_state(self, expert_traj):
87 | expert_state_actions = expert_traj
88 | perm = np.arange(expert_state_actions.shape[0])
89 | np.random.shuffle(perm)
90 | expert_state_actions = expert_state_actions[perm].copy()
91 | expert_states = expert_state_actions[:, :self.state_dim]
92 |
93 | for i in range(expert_states.shape[0]):
94 | state = self.running_state(expert_states[i])
95 | # print(self.running_state.rs.n, self.running_state.rs.mean, self.running_state.rs.var)
96 | return self.running_state
97 |
98 | def pretrain_policy(self, epoches=150):
99 | state_tuples = torch.from_numpy(self.state_tuples).to(self.dtype).to(self.device)
100 | expert_states = state_tuples[:, :self.state_dim]
101 | expert_next_states = state_tuples[:, self.state_dim:self.state_dim*2]
102 | expert_actions = state_tuples[:, self.state_dim*2:]
103 |
104 |         ''' During pretraining we only optimize the policy mean and keep its std fixed to a constant. '''
105 | for i_epoch in range(epoches):
106 | idxs = np.arange(expert_states.shape[0])
107 | np.random.shuffle(idxs)
108 | num_batch = int(np.ceil(idxs.shape[-1] / self.args.optim_batch_size))
109 | for batch_num in range(num_batch):
110 | batch_idxs = idxs[batch_num * self.args.optim_batch_size : (batch_num + 1) * self.args.optim_batch_size]
111 | states = expert_states[batch_idxs].clone()
112 | mean_action, _, _ = self.policy_net.forward(states)
113 |
114 | next_states = self.goal_model.get_next_states(states)
115 | inverse_action = self.inverse_model.forward(states, next_states)
116 |
117 | policy_loss = ((inverse_action-mean_action)**2).mean()
118 |
119 | self.optimizer_policy.zero_grad()
120 | policy_loss.backward()
121 | self.optimizer_policy.step()
122 |
123 | #if (i_epoch+1) % 100 == 0:
124 | #adjust_lr(self.optimizer_policy, 2.)
125 | self.writer.add_scalar('loss/pretraining', policy_loss, i_epoch)
126 |
127 |
128 | def pretrain_policy_l2(self, epoches=80):
129 | expert_state_actions = torch.from_numpy(self.expert_traj).to(self.dtype).to(self.device)
130 | batch_num = int(math.ceil(expert_state_actions.shape[0] / 256))
131 | for i in range(epoches):
132 | perm = np.arange(expert_state_actions.shape[0])
133 | np.random.shuffle(perm)
134 | perm = LongTensor(perm).to(self.device)
135 | expert_state_actions = expert_state_actions[perm].clone()
136 | for b in range(batch_num):
137 | ind = slice(b * 256, min((b + 1) * 256, expert_state_actions.shape[0]))
138 | expert_sa_batch = expert_state_actions[ind]
139 | states = expert_sa_batch[:, :self.state_dim]
140 | actions = expert_sa_batch[:, self.state_dim:]
141 | policy_action, _, _ = self.policy_net.forward(states)
142 | self.optimizer_policy.zero_grad()
143 | loss = ((policy_action-actions)**2).mean(-1).mean(-1)
144 | loss.backward()
145 | self.optimizer_policy.step()
146 | if b == 0:
147 | self.writer.add_scalar('loss/pretrain_policy_l2', loss, i)
148 |
149 | def update_params(self, batch, i_iter, total_steps):
150 | states = torch.from_numpy(np.stack(batch.state)).to(self.dtype).to(self.device)
151 | actions = torch.from_numpy(np.stack(batch.action)).to(self.dtype).to(self.device)
152 | rewards = torch.from_numpy(np.stack(batch.reward)).to(self.dtype).to(self.device)
153 | masks = torch.from_numpy(np.stack(batch.mask)).to(self.dtype).to(self.device)
154 | next_states = torch.from_numpy(np.stack(batch.next_state)).to(self.dtype).to(self.device)
155 | with torch.no_grad():
156 | values = self.value_net(states)
157 | fixed_log_probs = self.policy_net.get_log_prob(states, actions)
158 |
159 | """get advantage estimation from the trajectories"""
160 | advantages, returns = estimate_advantages(rewards, masks, values, self.args.gamma, self.args.tau, self.device)
161 |
162 | """perform mini-batch PPO update"""
163 | optim_iter_num = int(math.ceil(states.shape[0] / self.args.optim_batch_size))
164 | for _ in range(self.args.optim_epochs):
165 | perm = np.arange(states.shape[0])
166 | np.random.shuffle(perm)
167 | perm = LongTensor(perm).to(self.device)
168 |
169 | states, actions, next_states, returns, advantages, fixed_log_probs = \
170 | states[perm].clone(), actions[perm].clone(), next_states[perm].clone(), returns[perm].clone(), advantages[perm].clone(), fixed_log_probs[perm].clone()
171 |
172 | if i_iter > self.args.freeze_policy_iter:
173 | self.policy_iter = self.args.policy_iter
174 |
175 | for i in range(optim_iter_num):
176 | ind = slice(i * self.args.optim_batch_size, min((i + 1) * self.args.optim_batch_size, states.shape[0]))
177 | states_b, actions_b, next_states_b, advantages_b, returns_b, fixed_log_probs_b = \
178 | states[ind], actions[ind], next_states[ind], advantages[ind], returns[ind], fixed_log_probs[ind]
179 |
180 | self.meta_ppo_step(states_b, actions_b, returns_b,
181 | advantages_b, fixed_log_probs_b, self.args.clip_epsilon, self.args.l2_reg, total_steps)
182 |
183 |
184 | def train(self):
185 | total_numsteps = 0
186 | for i_iter in range(self.args.max_iter_num):
187 | """generate multiple trajectories that reach the minimum batch_size"""
188 | to_device(torch.device('cpu'), self.policy_net, self.goal_model)
189 | batch, log = self.agent.collect_samples(self.args.min_batch_size)
190 | to_device(self.device, self.policy_net, self.goal_model)
191 |
192 | total_numsteps += log['num_steps']
193 | self.update_params(batch, i_iter, total_numsteps)
194 | self.train_models(batch, epoch=30)
195 | self.train_discrim(batch, epoch=30)
196 |
197 | if i_iter % self.args.log_interval == 0:
198 | print('{}\tT_sample {:.4f}\texpert_R_avg {:.2f}\tR_avg {:.2f}'.format(
199 | i_iter, log['sample_time'], log['avg_c_reward'], log['avg_reward']))
200 | self.writer.add_scalar('reward/env', log['avg_reward'], total_numsteps)
201 | self.writer.add_scalar('reward/fake', log['avg_c_reward'], total_numsteps)
202 |
203 | if self.args.save_model_interval > 0 and (i_iter+1) % self.args.save_model_interval == 0:
204 | self.save_model()
205 |
206 | """clean up gpu memory"""
207 | torch.cuda.empty_cache()
208 |
209 | def split_data(self, expert_traj, state_pairs, state_tuples):
210 | self.expert_traj = expert_traj
211 | self.state_pairs = state_pairs
212 | idxs = np.arange(state_tuples.shape[0])
213 | np.random.shuffle(idxs)
214 | train_idxes = idxs[:(state_tuples.shape[0]*19)//20]
215 | test_idxes = idxs[(state_tuples.shape[0]*19)//20:]
216 | self.state_tuples = state_tuples[train_idxes]
217 | self.test_state_tuples = state_tuples[test_idxes]
218 | print('split train and validation', self.state_tuples.shape, self.test_state_tuples.shape)
219 |
220 |
221 | def train_discrim(self, batch, epoch=30):
222 | self.trained = True
223 | expert_data = torch.Tensor(self.expert_traj).to(self.device)
224 | expert_data = expert_data[:, :self.state_dim]
225 | imitator_data = torch.from_numpy(np.stack(batch.state)).to(self.dtype).to(self.device)
226 | #print(imitator_data.shape)
227 | for _ in range(epoch):
228 | self.optimizer_discrim.zero_grad()
229 | self.discrim_net.zero_grad()
230 | D_real = self.discrim_net(expert_data)
231 | D_real = D_real.mean()
232 | D_fake = self.discrim_net(imitator_data)
233 | D_fake = D_fake.mean()
234 | loss = -D_real + D_fake
235 | loss.backward()
236 | # train with gradient penalty
237 | gradient_penalty = self.calc_gradient_penalty(expert_data.data, imitator_data.data)
238 | gradient_penalty.backward()
239 | self.wasserstein = D_real - D_fake
240 | #print('wasserstein distance', self.wasserstein.item())
241 | self.optimizer_discrim.step()
242 | print('final wasserstein distance', self.wasserstein.item())
243 | return self.wasserstein.item()
244 |
245 | def get_expert_mean(self):
246 | expert_data = torch.Tensor(self.expert_traj).to(self.device)
247 | expert_data = expert_data[:, :self.state_dim]
248 | expert_mean = self.discrim_net(expert_data).mean().item()
249 | return expert_mean
250 |
251 | def train_models(self, batch, epoch):
252 | states = torch.from_numpy(np.stack(batch.state)).to(self.dtype).to(self.device)
253 | actions = torch.from_numpy(np.stack(batch.action)).to(self.dtype).to(self.device)
254 | next_states = torch.from_numpy(np.stack(batch.next_state)).to(self.dtype).to(self.device)
255 | self.inverse_model.train(states, next_states, actions, self.optimizer_inverse, epoch=epoch, batch_size=self.args.optim_batch_size)
256 |
257 | def pretrain_vae(self, iter=200, epoch=2, lr_decay_rate=50):
258 | state_tuples = torch.from_numpy(self.state_tuples).to(self.dtype).to(self.device)
259 | s, t, action = state_tuples[:, :self.state_dim], state_tuples[:, self.state_dim:2*self.state_dim], \
260 | state_tuples[:, 2*self.state_dim:]
261 |
262 | state_tuples_test = torch.from_numpy(self.test_state_tuples).to(self.dtype).to(self.device)
263 | s_test, t_test, action_test = state_tuples_test[:, :self.state_dim], state_tuples_test[:, self.state_dim:2 * self.state_dim], \
264 | state_tuples_test[:, 2 * self.state_dim:]
265 |
266 |
267 | for i in range(1, iter + 1):
268 | loss = self.goal_model.train(s, t, epoch=epoch, optimizer=self.optimizer_vae, \
269 | batch_size=self.args.optim_batch_size, beta=self.args.beta)
270 | next_states = self.goal_model.get_next_states(s_test)
271 | val_error = ((t_test - next_states) ** 2).mean()
272 | self.writer.add_scalar('loss/vae', loss, i)
273 | self.writer.add_scalar('valid/vae', val_error, i)
274 |
275 | if i % lr_decay_rate == 0:
276 | adjust_lr(self.optimizer_vae, 2.)
277 |
278 | def pretrain_dynamics_with_l2(self, policy_epoch=50, iter=200, epoch=2, lr_decay_rate=50):
279 |         '''
280 |         Designed for the cross-morphology setting:
281 |         pretrain the policy with an L2 behavioral-cloning loss, then
282 |         collect data with the pretrained policy to train the dynamics model.
283 |         '''
284 | self.pretrain_policy_l2(epoches=policy_epoch)
285 | memory_bc = self.warm_up(steps=50000, use_policy=True)
286 | state_tuples = torch.from_numpy(memory_bc).to(self.dtype).to(self.device)
287 | s, t, action = state_tuples[:, :self.state_dim], state_tuples[:, self.state_dim:2*self.state_dim], \
288 | state_tuples[:, 2*self.state_dim:]
289 |
290 | # train the inverse model
291 | for i in range(1, iter + 1):
292 | loss = self.inverse_model.train(s, t, action, self.optimizer_inverse, epoch=epoch,
293 | batch_size=self.args.optim_batch_size)
294 | self.writer.add_scalar('loss/inverse', loss, i)
295 |
296 | if i % lr_decay_rate == 0:
297 | adjust_lr(self.optimizer_inverse, 2.)
298 |
299 | def pretrain_dynamics_with_demo(self, iter=200, epoch=2, lr_decay_rate=50):
300 |         '''
301 |         Designed for the standard setting:
302 |         train the inverse dynamics model directly on the expert demonstrations.
303 |         '''
304 | state_tuples = torch.from_numpy(self.state_tuples).to(self.dtype).to(self.device)
305 | s, t, action = state_tuples[:, :self.state_dim], state_tuples[:, self.state_dim:2*self.state_dim], \
306 | state_tuples[:, 2*self.state_dim:]
307 | state_tuples_test = torch.from_numpy(self.test_state_tuples).to(self.dtype).to(self.device)
308 | s_test, t_test, action_test = state_tuples_test[:, :self.state_dim], state_tuples_test[:,
309 | self.state_dim:2 * self.state_dim], \
310 | state_tuples_test[:, 2 * self.state_dim:]
311 |
312 | # train the inverse model
313 | for i in range(1, iter + 1):
314 | loss = self.inverse_model.train(s, t, action, self.optimizer_inverse, epoch=epoch, batch_size=self.args.optim_batch_size)
315 |             pred_action = self.inverse_model.forward(s_test, t_test)
316 |             val_error = ((pred_action - action_test) ** 2).mean()
317 | self.writer.add_scalar('loss/inverse', loss, i)
318 | self.writer.add_scalar('valid/inverse', val_error, i)
319 | if i % lr_decay_rate == 0:
320 | adjust_lr(self.optimizer_inverse, 2.)
321 |
322 |
323 | def meta_ppo_step(self, states, actions, returns, advantages, fixed_log_probs, clip_epsilon, l2_reg, total_step, lam1=0.001, lam2=0.01):
324 | """update critic"""
325 | for _ in range(self.value_iter):
326 | values_pred = self.value_net(states)
327 | value_loss = (values_pred - returns).pow(2).mean()
328 | # weight decay
329 | for param in self.value_net.parameters():
330 | value_loss += param.pow(2).sum() * l2_reg
331 | self.optimizer_value.zero_grad()
332 | value_loss.backward()
333 | torch.nn.utils.clip_grad_norm_(self.value_net.parameters(), 5)
334 | self.optimizer_value.step()
335 | self.writer.add_scalar('loss/value_loss', value_loss, total_step)
336 |
337 | for _ in range(self.policy_iter):
338 | """update policy"""
339 | log_probs = self.policy_net.get_log_prob(states, actions)
340 | ratio = torch.exp(log_probs - fixed_log_probs)
341 | surr1 = ratio * advantages
342 | surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
343 | policy_surr = -torch.min(surr1, surr2).mean()
344 | '''compute kl'''
345 | mean_action, action_log_std, action_std = self.policy_net.forward(states)
346 | next_states = self.goal_model.get_next_states(states)
347 | inverse_action = self.inverse_model.forward(states, next_states)
348 |
349 |             # lam2 is a hyper-parameter, proportional to sigma^(-2), the inverse variance of the Gaussian action prior
350 | policy_loss = -lam2 * action_log_std.sum() + (action_std**2).sum() + ((inverse_action-mean_action)**2).mean()
351 | policy_loss += lam1 * policy_surr # balance the policy surrogate
352 |
353 | self.optimizer_policy.zero_grad()
354 | policy_loss.backward()
355 | torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 5)
356 | self.optimizer_policy.step()
357 | self.writer.add_scalar('loss/policy_loss', policy_loss, total_step)
358 |
359 | def warm_up(self, steps=5000, use_policy=False):
360 | memory = []
361 | t = 0
362 | if use_policy is True:
363 | to_device(torch.device('cpu'), self.policy_net)
364 |
365 | while t < steps:
366 | obs = self.env.reset()
367 | if self.running_state is not None:
368 | obs = self.running_state(obs, update=False)
369 | while True:
370 | t += 1
371 | if use_policy is True:
372 | state_var = tensor(obs).to(self.dtype).unsqueeze(0)
373 | action = self.policy_net.select_action(state_var)[0].numpy()
374 | else:
375 | action = self.env.action_space.sample()
376 |
377 | next_obs, reward, done, _ = self.env.step(action)
378 | if self.running_state is not None:
379 | next_obs = self.running_state(next_obs, update=False)
380 | transition = np.concatenate([obs, next_obs, action])
381 | memory.append(transition)
382 | if done:
383 | break
384 | obs = next_obs
385 | memory = np.stack(memory)
386 | if use_policy is True:
387 | to_device(self.device, self.policy_net)
388 | return memory
389 |
390 | def eval(self, num=10):
391 | t = 0
392 | accu_reward = 0
393 | to_device(torch.device('cpu'), self.policy_net)
394 | while t < num:
395 | obs = self.env.reset()
396 | if self.running_state is not None:
397 | obs = self.running_state(obs, update=False)
398 | while True:
399 | state_var = tensor(obs).to(self.dtype).unsqueeze(0)
400 | action = self.policy_net(state_var)[0][0].detach().numpy()
401 | next_obs, reward, done, _ = self.env.step(action)
402 | accu_reward += reward
403 | if self.running_state is not None:
404 | next_obs = self.running_state(next_obs, update=False)
405 | if done:
406 | t += 1
407 | break
408 | obs = next_obs
409 | to_device(self.device, self.policy_net)
410 | print('accumulated reward', accu_reward/num)
411 | return accu_reward/num
412 |
413 | def calc_gradient_penalty(self, real_data, fake_data):
414 | alpha = torch.rand(fake_data.shape[0], 1)
415 | idx = np.random.randint(0, len(real_data), fake_data.shape[0])
416 | real_data_b = real_data[idx]
417 |
418 | alpha = alpha.expand(real_data_b.size())
419 | alpha = alpha.to(self.device)
420 |
421 | interpolates = alpha * real_data_b + ((1 - alpha) * fake_data)
422 | interpolates = interpolates.to(self.device)
423 | interpolates = autograd.Variable(interpolates, requires_grad=True)
424 |
425 | disc_interpolates = self.discrim_net(interpolates)
426 |
427 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
428 | grad_outputs=torch.ones(disc_interpolates.size()).to(self.device),
429 | create_graph=True, retain_graph=True, only_inputs=True)[0]
430 |
431 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.args.lam
432 | return gradient_penalty
433 |
434 | def save_model(self):
435 | to_device(torch.device('cpu'), self.inverse_model, self.policy_net, self.value_net, self.goal_model)
436 | print('saving models')
437 | torch.save([self.inverse_model.state_dict(), self.policy_net.state_dict(), self.value_net.state_dict(),
438 | self.goal_model.state_dict()], assets_dir()+'/learned_models/{}_{}_models.pt'.format(self.args.env_name, str(self.args.beta)))
439 | to_device(self.device, self.inverse_model, self.policy_net, self.value_net, self.goal_model)
440 |
441 | def load_model(self):
442 | model_path = assets_dir()+'/learned_models/{}_{}_models.pt'.format(self.args.env_name, str(self.args.beta))
443 | print('load model from', model_path)
444 | to_device(torch.device('cpu'), self.inverse_model, self.policy_net, self.value_net, self.goal_model)
445 | pretrained_dict = torch.load(model_path)
446 | self.inverse_model.load_state_dict(pretrained_dict[0])
447 | self.policy_net.load_state_dict(pretrained_dict[1])
448 | self.value_net.load_state_dict(pretrained_dict[2])
449 | self.goal_model.load_state_dict(pretrained_dict[3])
450 | to_device(self.device, self.inverse_model, self.policy_net, self.value_net, self.goal_model)
451 |
452 | def save_replay(self):
453 | from utils.render import play
454 | to_device(torch.device('cpu'), self.policy_net)
455 | video_path = self.save_path+'.avi'
456 | play(self.env, self.policy_net, self.running_state, video_path=video_path, time_limit=1000, device='cpu')
457 | to_device(self.device, self.policy_net)
458 |
--------------------------------------------------------------------------------
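Putting the pieces together, SoftBC_agent is pretrained from expert data (VAE goal model, inverse dynamics, policy) and then fine-tuned on-policy with the Wasserstein-discriminator reward. The sketch below only shows an assumed call order built from the methods defined in this file; the real entry point is sail.py, and the construction of args, expert_traj, state_pairs and state_tuples is not reproduced here.

from agents.soft_bc_agent import SoftBC_agent

def run_soft_bc(args, expert_traj, state_pairs, state_tuples):
    """Assumed wiring of SoftBC_agent; the array arguments mirror split_data's inputs."""
    agent = SoftBC_agent(args)
    agent.preprocess_running_state(expert_traj)   # warm up the ZFilter on expert states
    agent.split_data(expert_traj, state_pairs, state_tuples)
    agent.pretrain_vae(iter=200)                  # goal model: predict s' from s
    agent.pretrain_dynamics_with_demo(iter=200)   # inverse model: predict a from (s, s')
    agent.pretrain_policy(epoches=150)            # align the policy mean with inverse-model actions
    agent.train()                                 # on-policy fine-tuning with the WGAN reward
    return agent.eval(num=10)

For the cross-morphology experiments, pretrain_dynamics_with_l2 would presumably replace the demo-based dynamics pretraining, since it first clones the policy with an L2 loss and then collects its own transitions via warm_up.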