├── expert ├── assets │ └── expert_traj │ │ ├── Hopper-v2_expert_traj.p │ │ └── Swimmer-v2_expert_traj.p ├── sac_args.py ├── save_traj_ppo.py ├── save_traj_sac.py ├── train_sac.py └── train_ppo.py ├── utils ├── __init__.py ├── tools.py ├── math.py ├── zfilter.py ├── replay_memory.py ├── render.py ├── torch.py └── utils.py ├── envs ├── test_envs.py ├── mujoco │ ├── __init__.py │ ├── assets │ │ ├── light_swimmer.xml │ │ ├── swimmer.xml │ │ ├── heavy_swimmer.xml │ │ ├── disabled_swimmer.xml │ │ ├── ant.xml │ │ ├── light_ant.xml │ │ ├── heavy_ant.xml │ │ └── disabled_ant.xml │ ├── swimmer.py │ └── ant.py ├── __init__.py ├── cycle_4room.py ├── mountaincar.py └── fourroom.py ├── core ├── common.py ├── ppo.py └── agent.py ├── LICENSE ├── models ├── rnd_model.py ├── WGAN.py ├── VAE.py ├── sac_models.py ├── ppo_models.py └── dynamics.py ├── .gitignore ├── README.md ├── sail.py └── agents ├── sac_agent.py └── soft_bc_agent.py /expert/assets/expert_traj/Hopper-v2_expert_traj.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FangchenLiu/SAIL/HEAD/expert/assets/expert_traj/Hopper-v2_expert_traj.p -------------------------------------------------------------------------------- /expert/assets/expert_traj/Swimmer-v2_expert_traj.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FangchenLiu/SAIL/HEAD/expert/assets/expert_traj/Swimmer-v2_expert_traj.p -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from utils.replay_memory import * 2 | from utils.zfilter import * 3 | from utils.torch import * 4 | from utils.math import * 5 | from utils.tools import * 6 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import torch 3 | 4 | def assets_dir(): 5 | return path.abspath(path.join(path.dirname(path.abspath(__file__)), '../expert/assets')) 6 | 7 | def swish(x): 8 | return x * torch.sigmoid(x) 9 | -------------------------------------------------------------------------------- /utils/math.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | def normal_entropy(std): 6 | var = std.pow(2) 7 | entropy = 0.5 + 0.5 * torch.log(2 * var * math.pi) 8 | return entropy.sum(1, keepdim=True) 9 | 10 | 11 | def normal_log_density(x, mean, log_std, std): 12 | var = std.pow(2) 13 | log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) - log_std 14 | return log_density.sum(1, keepdim=True) 15 | -------------------------------------------------------------------------------- /envs/test_envs.py: -------------------------------------------------------------------------------- 1 | from utils.render import play 2 | import gym 3 | import envs 4 | import envs.mujoco 5 | 6 | if __name__ == '__main__': 7 | env = gym.make('DisableAnt-v0') 8 | play(env, None, None, video_path='disable_ant.avi', time_limit=200, device='cpu') 9 | env = gym.make('LightAnt-v0') 10 | play(env, None, None, video_path='light_ant.avi', time_limit=200, device='cpu') 11 | env = gym.make('HeavyAnt-v0') 12 | play(env, None, None, video_path='heavy_ant.avi', time_limit=200, device='cpu') 
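The `normal_log_density` helper in `utils/math.py` above sums the per-dimension Gaussian log-density over the last axis. As a minimal sanity check (not part of the repo; the shapes below are purely illustrative), it can be compared against `torch.distributions.Normal`, which computes the same quantity:

```python
import torch
from torch.distributions import Normal
from utils.math import normal_log_density

# Illustrative shapes: a batch of 4 samples from a 3-dimensional diagonal Gaussian.
mean = torch.zeros(4, 3)
log_std = torch.full((4, 3), -0.5)
std = log_std.exp()
x = torch.randn(4, 3)

ours = normal_log_density(x, mean, log_std, std)          # shape (4, 1)
ref = Normal(mean, std).log_prob(x).sum(1, keepdim=True)  # same density via torch.distributions
assert torch.allclose(ours, ref, atol=1e-5)
```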
-------------------------------------------------------------------------------- /core/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import to_device 3 | 4 | 5 | def estimate_advantages(rewards, masks, values, gamma, tau, device): 6 | rewards, masks, values = to_device(torch.device('cpu'), rewards, masks, values) 7 | tensor_type = type(rewards) 8 | deltas = tensor_type(rewards.size(0), 1) 9 | advantages = tensor_type(rewards.size(0), 1) 10 | 11 | prev_value = 0 12 | prev_advantage = 0 13 | for i in reversed(range(rewards.size(0))): 14 | deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values[i] 15 | advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i] 16 | 17 | prev_value = values[i, 0] 18 | prev_advantage = advantages[i, 0] 19 | 20 | returns = values + advantages 21 | advantages = (advantages - advantages.mean()) / advantages.std() 22 | 23 | advantages, returns = to_device(device, advantages, returns) 24 | return advantages, returns -------------------------------------------------------------------------------- /envs/mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='DisableAnt-v0', 5 | entry_point='envs.mujoco.ant:DisableAntEnv', 6 | max_episode_steps=1000, 7 | reward_threshold=6000.0, 8 | ) 9 | 10 | register( 11 | id='HeavyAnt-v0', 12 | entry_point='envs.mujoco.ant:HeavyAntEnv', 13 | max_episode_steps=1000, 14 | reward_threshold=6000.0, 15 | ) 16 | 17 | register( 18 | id='LightAnt-v0', 19 | entry_point='envs.mujoco.ant:LightAntEnv', 20 | max_episode_steps=1000, 21 | reward_threshold=6000.0, 22 | ) 23 | 24 | register( 25 | id='DisableSwimmer-v0', 26 | entry_point='envs.mujoco.swimmer:DisableSwimmerEnv', 27 | max_episode_steps=1000, 28 | reward_threshold=360.0, 29 | ) 30 | register( 31 | id='LightSwimmer-v0', 32 | entry_point='envs.mujoco.swimmer:LightSwimmerEnv', 33 | max_episode_steps=1000, 34 | reward_threshold=360.0, 35 | ) 36 | register( 37 | id='HeavySwimmer-v0', 38 | entry_point='envs.mujoco.swimmer:HeavySwimmerEnv', 39 | max_episode_steps=1000, 40 | reward_threshold=360.0, 41 | ) -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | register( 3 | id='Fourroom-v0', 4 | entry_point='envs.fourroom:FourRoom', 5 | kwargs={'goal_type':'fix_goal'}, 6 | max_episode_steps=200, 7 | reward_threshold=100.0, 8 | nondeterministic=False, 9 | ) 10 | register( 11 | id='Fourroom-v1', 12 | entry_point='envs.fourroom:FourRoom1', 13 | kwargs={'goal_type':'fix_goal'}, 14 | max_episode_steps=200, 15 | reward_threshold=100.0, 16 | nondeterministic=False, 17 | ) 18 | register( 19 | id='CycleFourroom-v0', 20 | entry_point='envs.cycle_4room:FourRoom', 21 | max_episode_steps=200, 22 | reward_threshold=100.0, 23 | nondeterministic=False, 24 | ) 25 | register( 26 | id='CycleFourroom-v1', 27 | entry_point='envs.cycle_4room:FourRoom1', 28 | max_episode_steps=200, 29 | reward_threshold=100.0, 30 | nondeterministic=False, 31 | ) 32 | register( 33 | id='MyMountainCar-v0', 34 | entry_point='envs.mountaincar:Continuous_MountainCarEnv', 35 | max_episode_steps=999, 36 | reward_threshold=90.0, 37 | nondeterministic=False, 38 | ) -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 FangchenLiu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def ppo_step(policy_net, value_net, optimizer_policy, optimizer_value, optim_value_iternum, states, actions, 5 | returns, advantages, fixed_log_probs, clip_epsilon, l2_reg): 6 | 7 | """update critic""" 8 | for _ in range(optim_value_iternum): 9 | values_pred = value_net(states) 10 | value_loss = (values_pred - returns).pow(2).mean() 11 | # weight decay 12 | for param in value_net.parameters(): 13 | value_loss += param.pow(2).sum() * l2_reg 14 | optimizer_value.zero_grad() 15 | value_loss.backward() 16 | optimizer_value.step() 17 | 18 | """update policy""" 19 | log_probs = policy_net.get_log_prob(states, actions) 20 | ratio = torch.exp(log_probs - fixed_log_probs) 21 | surr1 = ratio * advantages 22 | surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages 23 | policy_surr = -torch.min(surr1, surr2).mean() 24 | optimizer_policy.zero_grad() 25 | policy_surr.backward() 26 | torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 40) 27 | optimizer_policy.step() 28 | -------------------------------------------------------------------------------- /models/rnd_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import init 4 | import numpy as np 5 | 6 | class Coskx(nn.Module): 7 | def __init__(self, k=50): 8 | super(Coskx, self).__init__() 9 | self.k = k 10 | def forward(self, input): 11 | return torch.cos(input * self.k) 12 | 13 | class RND(nn.Module): 14 | def __init__(self, d, last_size=512, ptb=10): 15 | super(RND, self).__init__() 16 | self.target = nn.Sequential( 17 | nn.Linear(d, 512), 18 | nn.LeakyReLU(), 19 | nn.Linear(512, 512), 20 | nn.LeakyReLU(), 21 | Coskx(100), 22 | nn.Linear(512, last_size) 23 | ) 24 | self.predictor = nn.Sequential( 25 | nn.Linear(d, 512), 26 | nn.LeakyReLU(), 27 | nn.Linear(512, 512), 28 | nn.LeakyReLU(), 29 | nn.Linear(512, 512), 30 | nn.LeakyReLU(), 31 | nn.Linear(512, last_size) 32 | ) 33 | for p in self.modules(): 34 | if isinstance(p, nn.Linear): 35 | init.orthogonal_(p.weight, np.sqrt(2)) 36 | 37 | for param in 
self.target.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, states): 41 | target_feature = self.target(states).detach() 42 | predict_feature = self.predictor(states) 43 | return ((predict_feature - target_feature) ** 2).mean(-1) 44 | 45 | def get_q(self, states): 46 | err = self.forward(states) 47 | return torch.exp(-10*err).detach() 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff: 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/dictionaries 10 | ./gail/load_expert_traj.py 11 | # Sensitive or high-churn files: 12 | .idea/**/dataSources/ 13 | .idea/**/dataSources.ids 14 | .idea/**/dataSources.xml 15 | .idea/**/dataSources.local.xml 16 | .idea/**/sqlDataSources.xml 17 | .idea/**/dynamic.xml 18 | .idea/**/uiDesigner.xml 19 | 20 | # Gradle: 21 | .idea/**/gradle.xml 22 | .idea/**/libraries 23 | 24 | # CMake 25 | cmake-build-debug/ 26 | cmake-build-release/ 27 | 28 | # Mongo Explorer plugin: 29 | .idea/**/mongoSettings.xml 30 | 31 | ## File-based project format: 32 | *.iws 33 | 34 | ## Plugin-specific files: 35 | 36 | # IntelliJ 37 | out/ 38 | runs/ 39 | 40 | # mpeltonen/sbt-idea plugin 41 | .idea_modules/ 42 | 43 | # JIRA plugin 44 | atlassian-ide-plugin.xml 45 | 46 | # Cursive Clojure plugin 47 | .idea/replstate.xml 48 | 49 | # Crashlytics plugin (for Android Studio and IntelliJ) 50 | com_crashlytics_export_strings.xml 51 | crashlytics.properties 52 | crashlytics-build.properties 53 | fabric.properties 54 | 55 | .idea/handful-of-trials-minimal-cartpole.iml 56 | .idea/inspectionProfiles/ 57 | .idea/markdown-navigator/ 58 | .idea/misc.xml 59 | .idea/modules.xml 60 | .idea/workspace.xml 61 | .idea/ 62 | __pycache__/ 63 | learned_models/ 64 | log/ 65 | *.lprof 66 | .DS_Store 67 | .pt 68 | .p 69 | *expert_traj.p 70 | -------------------------------------------------------------------------------- /utils/zfilter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # from https://github.com/joschu/modular_rl 4 | # http://www.johndcook.com/blog/standard_deviation/ 5 | 6 | 7 | class RunningStat(object): 8 | def __init__(self, shape): 9 | self._n = 0 10 | self._M = np.zeros(shape) 11 | self._S = np.zeros(shape) 12 | 13 | def push(self, x): 14 | x = np.asarray(x) 15 | assert x.shape == self._M.shape 16 | self._n += 1 17 | if self._n == 1: 18 | self._M[...] = x 19 | else: 20 | oldM = self._M.copy() 21 | self._M[...] = oldM + (x - oldM) / self._n 22 | self._S[...] 
= self._S + (x - oldM) * (x - self._M) 23 | 24 | @property 25 | def n(self): 26 | return self._n 27 | 28 | @property 29 | def mean(self): 30 | return self._M 31 | 32 | @property 33 | def var(self): 34 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) 35 | 36 | @property 37 | def std(self): 38 | return np.sqrt(self.var) 39 | 40 | @property 41 | def shape(self): 42 | return self._M.shape 43 | 44 | 45 | class ZFilter: 46 | """ 47 | y = (x-mean)/std 48 | using running estimates of mean,std 49 | """ 50 | 51 | def __init__(self, shape, demean=True, destd=True, clip=10.0): 52 | self.demean = demean 53 | self.destd = destd 54 | self.clip = clip 55 | self.rs = RunningStat(shape) 56 | 57 | def __call__(self, x, update=True): 58 | if update: 59 | self.rs.push(x) 60 | if self.demean: 61 | x = x - self.rs.mean 62 | if self.destd: 63 | x = x / (self.rs.std + 1e-8) 64 | if self.clip: 65 | x = np.clip(x, -self.clip, self.clip) 66 | return x 67 | -------------------------------------------------------------------------------- /utils/replay_memory.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import random 3 | import numpy as np 4 | 5 | 6 | Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state', 7 | 'mask')) 8 | 9 | class Memory: 10 | def __init__(self, capacity): 11 | self.capacity = int(capacity) 12 | self.buffer = [] 13 | self.position = 0 14 | 15 | def push(self, state, action, reward, next_state, done): 16 | if len(self.buffer) < self.capacity: 17 | self.buffer.append(None) 18 | self.buffer[self.position] = (state, action, reward, next_state, done) 19 | self.position = (self.position + 1) % self.capacity 20 | 21 | def sample(self, batch_size): 22 | batch = random.sample(self.buffer, batch_size) 23 | #print(batch) 24 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 25 | return state, action, reward, next_state, done 26 | 27 | def sample_all(self): 28 | batch = self.buffer 29 | state, action, reward, next_state, done = map(np.stack, zip(*batch)) 30 | return state, action, reward, next_state, done 31 | 32 | def __len__(self): 33 | return len(self.buffer) 34 | 35 | class OnlineMemory(object): 36 | def __init__(self): 37 | self.memory = [] 38 | 39 | def push(self, *args): 40 | """Saves a transition.""" 41 | self.memory.append(Transition(*args)) 42 | 43 | def sample(self, batch_size=None): 44 | if batch_size is None: 45 | return Transition(*zip(*self.memory)) 46 | else: 47 | random_batch = random.sample(self.memory, batch_size) 48 | return Transition(*zip(*random_batch)) 49 | 50 | def append(self, new_memory): 51 | self.memory += new_memory.memory 52 | 53 | def __len__(self): 54 | return len(self.memory) 55 | 56 | def clear(self): 57 | self.memory = [] 58 | 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # State Alignment-based Imitation Learning 2 | We propose a state-based imitation learning method for cross-morphology imitation learning, by considering both the state visitation distribution and local transition alignment. 3 | 4 | ### ATTENTION: The code is in a pre-release stage and under maintenance. Some details may be different from the paper or not included, due to internal integration for other projects. The author will reorganize some parts to match the original paper asap. 
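Before training, it can help to inspect the demonstration files referenced in the next section. Assuming the pickle layout produced by `expert/save_traj_ppo.py` further below (a list of per-episode arrays with state and action concatenated at each step, optionally paired with a running-state normalizer), a quick check looks like this; the shipped `.p` files may differ slightly, so treat it as a sketch:

```python
import pickle

# Path as laid out in this repo; adjust if the demos were downloaded elsewhere.
path = './expert/assets/expert_traj/Hopper-v2_expert_traj.p'
with open(path, 'rb') as f:
    data = pickle.load(f)

# save_traj_ppo.py dumps either a bare list of trajectories or a
# (trajectories, running_state) tuple, so handle both layouts.
trajs = data[0] if isinstance(data, tuple) else data
print(len(trajs), 'episodes')
print(trajs[0].shape)  # (episode_length, state_dim + action_dim)
```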
5 | 6 | ## Train 7 | The demonstrations are provided in ``./expert/assets/expert_traj``. They were generated by well-trained expert policies (SAC or TRPO). 8 | 9 | The cross-morphology imitator environments can be found in ``./envs/mujoco/assets``. You can also customize your own environment for training. 10 | 11 | The main algorithm supports two settings: original imitation learning and cross-morphology imitation learning. 12 | The only difference between them is the pretraining stage for the inverse dynamics model. 13 | 14 | The general command format is 15 | ```bash 16 | python sail.py --env-name [YOUR-ENV-NAME] --expert-traj-path [PATH-TO-DEMO] --beta 0.01 --resume [if resuming] --transfer [if cross-morphology] 17 | ``` 18 | For example, for original Hopper imitation: 19 | ```bash 20 | python sail.py --env-name Hopper-v2 --expert-traj-path ./expert/assets/expert_traj/Hopper-v2_expert_traj.p --beta 0.005 21 | ``` 22 | For disabled swimmer imitation: 23 | ```bash 24 | python sail.py --env-name DisableSwimmer-v0 --expert-traj-path ./expert/assets/expert_traj/Swimmer-v2_expert_traj.p --beta 0.005 --transfer 25 | ``` 26 | 27 | ## Cite Our Paper 28 | If you find this work useful, please consider citing our paper. 29 | ``` 30 | @article{liu2019state, 31 | title={State Alignment-based Imitation Learning}, 32 | author={Liu, Fangchen and Ling, Zhan and Mu, Tongzhou and Su, Hao}, 33 | journal={arXiv preprint arXiv:1911.10947}, 34 | year={2019} 35 | } 36 | ``` 37 | 38 | ## Demonstrations 39 | Please download the demonstrations [here](https://drive.google.com/open?id=1cIqYevPDE2_06Elyo_UqEfHiQvOKP6in) and put them in ```./expert/assets/expert_traj``` 40 | -------------------------------------------------------------------------------- /utils/render.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | 5 | def play(env, policy, running_state=None, video_path="tmp.avi", time_limit=999, device='cpu'): 6 | out = None 7 | obs = env.reset() 8 | if running_state is not None: 9 | obs = running_state(obs, update=False) 10 | num = 0 11 | 12 | while True: 13 | img = env.unwrapped.render(mode='rgb_array')[:, :, ::-1].copy() 14 | if out is None: 15 | out = cv2.VideoWriter( 16 | video_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (img.shape[1], img.shape[0])) 17 | out.write(img) 18 | if policy is not None: 19 | obs = torch.tensor(obs).float().unsqueeze(0).to(device) 20 | action = policy.select_action(obs)[0].detach().cpu().numpy() 21 | action = int(action) if policy.is_disc_action else action.astype(np.float32) 22 | else: 23 | action = env.action_space.sample() 24 | obs, rew, done, info = env.step(action) 25 | if running_state is not None: 26 | obs = running_state(obs, update=False) 27 | if done: 28 | obs = env.reset() 29 | num += 1 30 | #assert not info['is_success'] 31 | flag = True 32 | if not flag: 33 | print(num, info, rew, done, env.goal, action) 34 | if num == time_limit - 1: 35 | break 36 | env.close() 37 | 38 | def play_action_seq(env, action_seq, video_path="tmp.avi", time_limit=999): 39 | out = None 40 | obs = env.reset() 41 | t = 0 42 | reward = 0 43 | while True: 44 | img = env.unwrapped.render(mode='rgb_array')[:, :, ::-1].copy() 45 | if out is None: 46 | out = cv2.VideoWriter( 47 | video_path, cv2.VideoWriter_fourcc(*'XVID'), 20.0, (img.shape[1], img.shape[0])) 48 | out.write(img) 49 | action = action_seq[t] 50 | obs, rew, done, info = env.step(action) 51 | reward+=rew 52 | t+=1 53 | if t > time_limit: 54 | print('accumulated reward', reward) 55 | break 56 |
env.close() -------------------------------------------------------------------------------- /utils/torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from models.dynamics import ForwardModel, InverseModel 4 | 5 | tensor = torch.tensor 6 | DoubleTensor = torch.DoubleTensor 7 | FloatTensor = torch.FloatTensor 8 | LongTensor = torch.LongTensor 9 | ByteTensor = torch.ByteTensor 10 | ones = torch.ones 11 | zeros = torch.zeros 12 | 13 | 14 | def to_device(device, *args): 15 | for x in args: 16 | if isinstance(x, ForwardModel) or isinstance(x, InverseModel): 17 | x.device = device 18 | return [x.to(device) for x in args] 19 | 20 | 21 | def get_flat_params_from(model): 22 | params = [] 23 | for param in model.parameters(): 24 | params.append(param.view(-1)) 25 | 26 | flat_params = torch.cat(params) 27 | return flat_params 28 | 29 | 30 | def set_flat_params_to(model, flat_params): 31 | prev_ind = 0 32 | for param in model.parameters(): 33 | flat_size = int(np.prod(list(param.size()))) 34 | param.data.copy_( 35 | flat_params[prev_ind:prev_ind + flat_size].view(param.size())) 36 | prev_ind += flat_size 37 | 38 | 39 | def get_flat_grad_from(inputs, grad_grad=False): 40 | grads = [] 41 | for param in inputs: 42 | if grad_grad: 43 | grads.append(param.grad.grad.view(-1)) 44 | else: 45 | if param.grad is None: 46 | grads.append(zeros(param.view(-1).shape)) 47 | else: 48 | grads.append(param.grad.view(-1)) 49 | 50 | flat_grad = torch.cat(grads) 51 | return flat_grad 52 | 53 | 54 | def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=False, create_graph=False): 55 | if create_graph: 56 | retain_graph = True 57 | 58 | inputs = list(inputs) 59 | params = [] 60 | for i, param in enumerate(inputs): 61 | if i not in filter_input_ids: 62 | params.append(param) 63 | 64 | grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph) 65 | 66 | j = 0 67 | out_grads = [] 68 | for i, param in enumerate(inputs): 69 | if i in filter_input_ids: 70 | out_grads.append(zeros(param.view(-1).shape, device=param.device, dtype=param.dtype)) 71 | else: 72 | out_grads.append(grads[j].view(-1)) 73 | j += 1 74 | grads = torch.cat(out_grads) 75 | 76 | for param in params: 77 | param.grad = None 78 | return grads 79 | -------------------------------------------------------------------------------- /envs/mujoco/assets/light_swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /envs/mujoco/assets/swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /envs/mujoco/assets/heavy_swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /envs/mujoco/assets/disabled_swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/WGAN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | class Generator(nn.Module): 5 | 6 | def __init__(self, input_dim, output_dim, hidden_size=400): 7 | super(Generator, 
self).__init__() 8 | main = nn.Sequential( 9 | nn.Linear(input_dim, hidden_size), 10 | nn.ReLU(), 11 | nn.Linear(hidden_size, hidden_size), 12 | nn.ReLU(), 13 | nn.Linear(hidden_size, hidden_size), 14 | nn.ReLU(), 15 | nn.Linear(hidden_size, output_dim), 16 | ) 17 | self.main = main 18 | 19 | def forward(self, noise): 20 | output = self.main(noise) 21 | return output 22 | 23 | class Discriminator(nn.Module): 24 | def __init__(self, num_inputs, layers=3, hidden_size=400, activation='tanh'): 25 | super(Discriminator, self).__init__() 26 | if activation == 'tanh': 27 | self.activation = torch.tanh 28 | elif activation == 'relu': 29 | self.activation = torch.relu 30 | elif activation == 'sigmoid': 31 | self.activation = torch.sigmoid 32 | self.num_layers = layers 33 | self.affine_layers = nn.ModuleList() 34 | 35 | for i in range(self.num_layers): 36 | if i == 0: 37 | self.affine_layers.append(nn.Linear(num_inputs, hidden_size)) 38 | else: 39 | self.affine_layers.append(nn.Linear(hidden_size, hidden_size)) 40 | 41 | self.logic = nn.Linear(hidden_size, 1) 42 | self.logic.weight.data.mul_(0.1) 43 | self.logic.bias.data.mul_(0.0) 44 | 45 | def forward(self, x): 46 | for affine in self.affine_layers: 47 | x = self.activation(affine(x)) 48 | prob = torch.sigmoid(self.logic(x)) 49 | return prob 50 | 51 | #for wgan: remove sigmoid in the last layer of discriminator 52 | class W_Discriminator(nn.Module): 53 | 54 | def __init__(self, input_dim, hidden_size=400, layers=3): 55 | super(W_Discriminator, self).__init__() 56 | self.activation = torch.relu 57 | self.num_layers = layers 58 | self.affine_layers = nn.ModuleList() 59 | 60 | for i in range(self.num_layers): 61 | if i == 0: 62 | self.affine_layers.append(nn.Linear(input_dim, hidden_size)) 63 | else: 64 | self.affine_layers.append(nn.Linear(hidden_size, hidden_size)) 65 | 66 | self.final = nn.Linear(hidden_size, 1) 67 | self.final.weight.data.mul_(0.1) 68 | self.final.bias.data.mul_(0.0) 69 | 70 | def forward(self, x): 71 | for affine in self.affine_layers: 72 | x = self.activation(affine(x)) 73 | res = self.final(x) 74 | return res.view(-1) 75 | 76 | def weights_init(m): 77 | classname = m.__class__.__name__ 78 | if classname.find('Linear') != -1: 79 | m.weight.data.normal_(0.0, 0.02) 80 | m.bias.data.fill_(0) 81 | elif classname.find('BatchNorm') != -1: 82 | m.weight.data.normal_(1.0, 0.02) 83 | m.bias.data.fill_(0) -------------------------------------------------------------------------------- /expert/sac_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_sac_args(): 4 | parser = argparse.ArgumentParser(description='PyTorch GAIL example') 5 | parser.add_argument('--env-name', default="Hopper-v2", 6 | help='name of the environment to run') 7 | parser.add_argument('--policy', default="Gaussian", 8 | help='algorithm to use: Gaussian | Deterministic') 9 | parser.add_argument('--eval', type=bool, default=True, 10 | help='Evaluates a policy a policy every 10 episode (default:True)') 11 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', 12 | help='discount factor for reward (default: 0.99)') 13 | parser.add_argument('--tau', type=float, default=0.005, metavar='G', 14 | help='target smoothing coefficient(τ) (default: 0.005)') 15 | parser.add_argument('--lr', type=float, default=0.0003, metavar='G', 16 | help='learning rate (default: 0.0003)') 17 | parser.add_argument('--alpha', type=float, default=0.2, metavar='G', 18 | help='Temperature parameter α determines 
the relative importance of the entropy term against the reward (default: 0.2)') 19 | parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G', 20 | help='automatically adjust the temperature parameter α (default: False)') 21 | parser.add_argument('--seed', type=int, default=456, metavar='N', 22 | help='random seed (default: 456)') 23 | parser.add_argument('--batch-size', type=int, default=256, metavar='N', 24 | help='batch size (default: 256)') 25 | parser.add_argument('--num-steps', type=int, default=1000001, metavar='N', 26 | help='maximum number of steps (default: 1000000)') 27 | parser.add_argument('--hidden-size', type=int, default=400, metavar='N', 28 | help='hidden size (default: 400)') 29 | parser.add_argument('--updates-per-step', type=int, default=1, metavar='N', 30 | help='model updates per simulator step (default: 1)') 31 | parser.add_argument('--start-steps', type=int, default=300, metavar='N', 32 | help='Steps sampling random actions (default: 300)') 33 | parser.add_argument('--target-update-interval', type=int, default=1, metavar='N', 34 | help='Value target update per no. of updates per step (default: 1)') 35 | parser.add_argument('--replay-size', type=int, default=1000000, metavar='N', 36 | help='size of replay buffer (default: 1000000)') 37 | parser.add_argument('--device', type=str, default="cuda:0", 38 | help='device to run on (default: cuda:0)') 39 | parser.add_argument('--resume', type=bool, default=False, 40 | help='resume training from a saved model (default: False)') 41 | parser.add_argument('--model-path', type=str, default='learned_models/sac') 42 | parser.add_argument('--expert-traj-path', metavar='G', 43 | help='path of the expert trajectories') 44 | parser.add_argument('--rnd-epoch', type=int, default=600, 45 | help='number of epochs for RND training (default: 600)') 46 | args = parser.parse_args() 47 | 48 | return args -------------------------------------------------------------------------------- /expert/save_traj_ppo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gym 3 | import os 4 | import sys 5 | import pickle 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 7 | 8 | from itertools import count 9 | from utils import * 10 | 11 | parser = argparse.ArgumentParser(description='Save expert trajectory') 12 | parser.add_argument('--env-name', default="Hopper-v2", metavar='G', 13 | help='name of the environment to run') 14 | parser.add_argument('--model-path', metavar='G', 15 | help='path of the expert model') 16 | parser.add_argument('--render', action='store_true', default=False, 17 | help='render the environment') 18 | parser.add_argument('--seed', type=int, default=1, metavar='N', 19 | help='random seed (default: 1)') 20 | parser.add_argument('--max-expert-state-num', type=int, default=50000, metavar='N', 21 | help='maximal number of expert states to collect (default: 50000)') 22 | parser.add_argument('--running-state', type=int, default=0) 23 | args = parser.parse_args() 24 | 25 | dtype = torch.float32 26 | torch.set_default_dtype(dtype) 27 | env = gym.make(args.env_name) 28 | env.seed(args.seed) 29 | torch.manual_seed(args.seed) 30 | is_disc_action = len(env.action_space.shape) == 0 31 | state_dim = env.observation_space.shape[0] 32 | 33 | if args.running_state == 1: 34 | print('use running state') 35 | policy_net, _, running_state = pickle.load(open(args.model_path, "rb")) 36 | else: 37 | print('no running state') 38 | policy_net, _ = pickle.load(open(args.model_path, "rb")) 39 | 40 |
expert_trajs = [] 41 | policy_net.to(dtype) 42 | def main_loop(): 43 | 44 | num_steps = 0 45 | 46 | for i_episode in count(): 47 | expert_traj = [] 48 | state = env.reset() 49 | if args.running_state: 50 | state = running_state(state) 51 | reward_episode = 0 52 | 53 | for t in range(10000): 54 | state_var = tensor(state).unsqueeze(0).to(dtype) 55 | # choose mean action 56 | action = policy_net(state_var)[0][0].detach().numpy() 57 | # choose stochastic action 58 | # action = policy_net.select_action(state_var)[0].cpu().numpy() 59 | action = int(action) if is_disc_action else action.astype(np.float64) 60 | next_state, reward, done, _ = env.step(action) 61 | if args.running_state: 62 | next_state = running_state(next_state) 63 | reward_episode += reward 64 | num_steps += 1 65 | 66 | expert_traj.append(np.hstack([state, action])) 67 | 68 | if args.render: 69 | env.render() 70 | if done: 71 | expert_traj = np.stack(expert_traj) 72 | expert_trajs.append(expert_traj) 73 | break 74 | 75 | state = next_state 76 | 77 | print('Episode {}\t reward: {:.2f}'.format(i_episode, reward_episode)) 78 | 79 | if num_steps >= args.max_expert_state_num: 80 | break 81 | 82 | 83 | main_loop() 84 | if args.running_state: 85 | pickle.dump((expert_trajs, running_state), open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(args.env_name)), \ 86 | 'wb')) 87 | else: 88 | pickle.dump(expert_trajs, open(os.path.join(assets_dir(), 'expert_traj/{}_expert_traj.p'.format(args.env_name)),\ 89 | 'wb')) 90 | -------------------------------------------------------------------------------- /expert/save_traj_sac.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import gym 3 | import itertools 4 | from agents.sac_agent import SAC_agent 5 | from utils import * 6 | import argparse 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser(description='PyTorch GAIL example') 10 | parser.add_argument('--env-name', default="Hopper-v2", 11 | help='name of the environment to run') 12 | parser.add_argument('--policy', default="Gaussian", 13 | help='algorithm to use: Gaussian | Deterministic') 14 | parser.add_argument('--eval', type=bool, default=True, 15 | help='Evaluates a policy a policy every 10 episode (default:True)') 16 | parser.add_argument('--gamma', type=float, default=0.99, metavar='G', 17 | help='discount factor for reward (default: 0.99)') 18 | parser.add_argument('--tau', type=float, default=0.005, metavar='G', 19 | help='target smoothing coefficient(τ) (default: 0.005)') 20 | parser.add_argument('--lr', type=float, default=0.0003, metavar='G', 21 | help='learning rate (default: 0.0003)') 22 | parser.add_argument('--alpha', type=float, default=0.2, metavar='G', 23 | help='Temperature parameter α determines the relative importance of the entropy term against the reward (default: 0.2)') 24 | parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G', 25 | help='Temperature parameter α automaically adjusted.') 26 | parser.add_argument('--seed', type=int, default=456, metavar='N', 27 | help='random seed (default: 456)') 28 | parser.add_argument('--batch-size', type=int, default=256, metavar='N', 29 | help='batch size (default: 256)') 30 | parser.add_argument('--num-steps', type=int, default=1000001, metavar='N', 31 | help='maximum number of steps (default: 1000000)') 32 | parser.add_argument('--hidden-size', type=int, default=400, metavar='N', 33 | help='hidden size (default: 256)') 34 | 
parser.add_argument('--updates-per-step', type=int, default=1, metavar='N', 35 | help='model updates per simulator step (default: 1)') 36 | parser.add_argument('--start-steps', type=int, default=300, metavar='N', 37 | help='Steps sampling random actions (default: 10000)') 38 | parser.add_argument('--target-update-interval', type=int, default=1, metavar='N', 39 | help='Value target update per no. of updates per step (default: 1)') 40 | parser.add_argument('--replay-size', type=int, default=1e6, metavar='N', 41 | help='size of replay buffer (default: 10000000)') 42 | parser.add_argument('--device', type=str, default="cuda:0", 43 | help='run on CUDA (default: False)') 44 | parser.add_argument('--actor-path', type=str, default='assets/learned_models/sac_actor_Hopper-v2_1', help='actor resume path') 45 | parser.add_argument('--critic-path', type=str, default='assets/learned_models/sac_critic_Hopper-v2_1', help='critic resume path') 46 | 47 | args = parser.parse_args() 48 | 49 | return args 50 | 51 | args = get_args() 52 | 53 | # Environment 54 | # env = NormalizedActions(gym.make(args.env_name)) 55 | env = gym.make(args.env_name) 56 | torch.manual_seed(args.seed) 57 | np.random.seed(args.seed) 58 | env.seed(args.seed) 59 | state_dim = env.observation_space.shape[0] 60 | agent = SAC_agent(env, env.observation_space.shape[0], env.action_space, args, running_state=None) 61 | agent.load_model(actor_path=args.actor_path, critic_path=args.critic_path) 62 | agent.save_expert_traj(max_step=50000) -------------------------------------------------------------------------------- /expert/train_sac.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import gym 3 | import itertools 4 | from agents.sac_agent import SAC_agent 5 | from tensorboardX import SummaryWriter 6 | from models.rnd_model import RND 7 | from expert.sac_args import get_sac_args 8 | from utils import * 9 | 10 | args = get_sac_args() 11 | 12 | # Environment 13 | # env = NormalizedActions(gym.make(args.env_name)) 14 | env = gym.make(args.env_name) 15 | torch.manual_seed(args.seed) 16 | np.random.seed(args.seed) 17 | env.seed(args.seed) 18 | state_dim = env.observation_space.shape[0] 19 | agent = SAC_agent(env, env.observation_space.shape[0], env.action_space, args, running_state=None) 20 | writer = SummaryWriter(log_dir='runs/{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name, 21 | args.policy, "autotune" if args.automatic_entropy_tuning else "")) 22 | # Memory 23 | memory = Memory(args.replay_size) 24 | 25 | # Training Loop 26 | total_numsteps = 0 27 | updates = 0 28 | #agent.save_expert_traj() 29 | for i_episode in itertools.count(1): 30 | episode_reward = 0 31 | episode_steps = 0 32 | done = False 33 | state = env.reset() 34 | 35 | while not done: 36 | if args.start_steps > total_numsteps: 37 | action = env.action_space.sample() # Sample random action 38 | else: 39 | action = agent.select_action(state) # Sample action from policy 40 | 41 | if len(memory) > args.batch_size: 42 | # Number of updates per step in environment 43 | for i in range(args.updates_per_step): 44 | # Update parameters of all the networks 45 | critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = agent.update_parameters(memory, args.batch_size, updates) 46 | 47 | writer.add_scalar('loss/critic_1', critic_1_loss, updates) 48 | writer.add_scalar('loss/critic_2', critic_2_loss, updates) 49 | writer.add_scalar('loss/policy', policy_loss, updates) 50 | 
writer.add_scalar('loss/entropy_loss', ent_loss, updates) 51 | writer.add_scalar('entropy_temprature/alpha', alpha, updates) 52 | updates += 1 53 | 54 | next_state, reward, done, _ = env.step(action) # Step 55 | episode_steps += 1 56 | total_numsteps += 1 57 | episode_reward += reward 58 | 59 | mask = 1 if episode_steps == env._max_episode_steps else float(not done) 60 | 61 | memory.push(state, action, reward, next_state, mask) # Append transition to memory 62 | 63 | state = next_state 64 | 65 | if total_numsteps > args.num_steps: 66 | break 67 | 68 | writer.add_scalar('reward/sac_train', episode_reward, total_numsteps) 69 | print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2))) 70 | 71 | if i_episode % 10 == 0 and args.eval == True: 72 | avg_reward = 0. 73 | episodes = 10 74 | for _ in range(episodes): 75 | state = env.reset() 76 | episode_reward = 0 77 | done = False 78 | while not done: 79 | action = agent.select_action(state, eval=True) 80 | 81 | next_state, reward, done, _ = env.step(action) 82 | episode_reward += reward 83 | 84 | 85 | state = next_state 86 | avg_reward += episode_reward 87 | avg_reward /= episodes 88 | writer.add_scalar('avg_reward/sac_test', avg_reward, total_numsteps) 89 | 90 | print("----------------------------------------") 91 | print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2))) 92 | print("----------------------------------------") 93 | 94 | if i_episode % 1000 == 1: 95 | agent.save_model(args.env_name, i_episode) 96 | # get trajectory of the expert 97 | # agent.save_expert_traj(max_step=50000) 98 | 99 | env.close() -------------------------------------------------------------------------------- /envs/mujoco/swimmer.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import os 6 | import numpy as np 7 | from gym import utils 8 | from gym.envs.mujoco import mujoco_env 9 | 10 | class DisableSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): 11 | def __init__(self): 12 | dir_path = os.path.dirname(os.path.realpath(__file__)) 13 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/disabled_swimmer.xml' % dir_path, 4) 14 | utils.EzPickle.__init__(self) 15 | 16 | def step(self, a): 17 | ctrl_cost_coeff = 0.0001 18 | xposbefore = self.sim.data.qpos[0] 19 | self.do_simulation(a, self.frame_skip) 20 | xposafter = self.sim.data.qpos[0] 21 | reward_fwd = (xposafter - xposbefore) / self.dt 22 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() 23 | reward = reward_fwd + reward_ctrl 24 | ob = self._get_obs() 25 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) 26 | 27 | def _get_obs(self): 28 | qpos = self.sim.data.qpos 29 | qvel = self.sim.data.qvel 30 | return np.concatenate([qpos.flat[2:], qvel.flat]) 31 | 32 | def reset_model(self): 33 | self.set_state( 34 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 35 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv) 36 | ) 37 | return self._get_obs() 38 | 39 | class HeavySwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): 40 | def __init__(self): 41 | dir_path = os.path.dirname(os.path.realpath(__file__)) 42 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/heavy_swimmer.xml' % dir_path, 4) 43 | utils.EzPickle.__init__(self) 44 | 45 | def step(self, a): 46 | 
ctrl_cost_coeff = 0.0001 47 | xposbefore = self.sim.data.qpos[0] 48 | self.do_simulation(a, self.frame_skip) 49 | xposafter = self.sim.data.qpos[0] 50 | reward_fwd = (xposafter - xposbefore) / self.dt 51 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() 52 | reward = reward_fwd + reward_ctrl 53 | ob = self._get_obs() 54 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) 55 | 56 | def _get_obs(self): 57 | qpos = self.sim.data.qpos 58 | qvel = self.sim.data.qvel 59 | return np.concatenate([qpos.flat[2:], qvel.flat]) 60 | 61 | def reset_model(self): 62 | self.set_state( 63 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 64 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv) 65 | ) 66 | return self._get_obs() 67 | 68 | class LightSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): 69 | def __init__(self): 70 | dir_path = os.path.dirname(os.path.realpath(__file__)) 71 | mujoco_env.MujocoEnv.__init__(self, '%s/assets/light_swimmer.xml' % dir_path, 4) 72 | utils.EzPickle.__init__(self) 73 | 74 | def step(self, a): 75 | ctrl_cost_coeff = 0.0001 76 | xposbefore = self.sim.data.qpos[0] 77 | self.do_simulation(a, self.frame_skip) 78 | xposafter = self.sim.data.qpos[0] 79 | reward_fwd = (xposafter - xposbefore) / self.dt 80 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() 81 | reward = reward_fwd + reward_ctrl 82 | ob = self._get_obs() 83 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) 84 | 85 | def _get_obs(self): 86 | qpos = self.sim.data.qpos 87 | qvel = self.sim.data.qvel 88 | return np.concatenate([qpos.flat[2:], qvel.flat]) 89 | 90 | def reset_model(self): 91 | self.set_state( 92 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 93 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv) 94 | ) 95 | return self._get_obs() -------------------------------------------------------------------------------- /models/VAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torchvision import transforms 7 | import torch.optim as optim 8 | from torch import nn 9 | import matplotlib.pyplot as plt 10 | from torch.distributions import Normal 11 | from models.ppo_models import weights_init_ 12 | 13 | MAX_LOG_STD = 0.5 14 | MIN_LOG_STD = -20 15 | 16 | def latent_loss(z_mean, z_stddev): 17 | mean_sq = z_mean * z_mean 18 | stddev_sq = z_stddev * z_stddev 19 | return 0.5 * torch.mean(mean_sq + stddev_sq - torch.log(stddev_sq) - 1) 20 | 21 | class Encoder(torch.nn.Module): 22 | def __init__(self, input_dim, latent_dim, hidden_size=128): 23 | super(Encoder, self).__init__() 24 | self.linear1 = nn.Linear(input_dim, hidden_size) 25 | self.linear2 = nn.Linear(hidden_size, hidden_size) 26 | self.linear3 = nn.Linear(hidden_size, hidden_size) 27 | self.mu = nn.Linear(hidden_size, latent_dim) 28 | self.log_std = nn.Linear(hidden_size, latent_dim) 29 | self.apply(weights_init_) 30 | 31 | def forward(self, x): 32 | x = F.relu(self.linear1(x)) 33 | x = F.relu(self.linear2(x)) 34 | x = F.relu(self.linear3(x)) 35 | mean = self.mu(x) 36 | log_std = self.log_std(x) 37 | log_std = torch.clamp(log_std, min=MIN_LOG_STD, max=MAX_LOG_STD) 38 | return mean, log_std 39 | 40 | 41 | class Decoder(torch.nn.Module): 42 | def __init__(self, input_dim, out_dim, hidden_size=128): 43 
| super(Decoder, self).__init__() 44 | self.linear1 = torch.nn.Linear(input_dim, hidden_size) 45 | self.linear2 = torch.nn.Linear(hidden_size, hidden_size) 46 | self.linear3 = torch.nn.Linear(hidden_size, out_dim) 47 | self.apply(weights_init_) 48 | 49 | def forward(self, x): 50 | x = F.relu(self.linear1(x)) 51 | x = F.relu(self.linear2(x)) 52 | x = self.linear3(x) 53 | return x 54 | 55 | 56 | class VAE(torch.nn.Module): 57 | def __init__(self, state_dim, hidden_size=128, latent_dim=128): 58 | super(VAE, self).__init__() 59 | self.hidden_size = hidden_size 60 | self.encoder = Encoder(state_dim, latent_dim=latent_dim, hidden_size=self.hidden_size) 61 | self.decoder = Decoder(latent_dim, state_dim, hidden_size=self.hidden_size) 62 | 63 | def forward(self, state): 64 | mu, log_sigma = self.encoder(state) 65 | sigma = torch.exp(log_sigma) 66 | sample = mu + torch.randn_like(mu)*sigma 67 | self.z_mean = mu 68 | self.z_sigma = sigma 69 | 70 | return self.decoder(sample) 71 | 72 | def to(self, device): 73 | self.encoder.to(device) 74 | self.decoder.to(device) 75 | 76 | def get_next_states(self, states): 77 | mu, log_sigma = self.encoder(states) 78 | return self.decoder(mu) 79 | 80 | def get_loss(self, state, next_state): 81 | next_pred = self.get_next_states(state) 82 | return ((next_state-next_pred)**2).mean() 83 | 84 | def train(self, input, target, epoch, optimizer, batch_size=128, beta=0.1): 85 | idxs = np.arange(input.shape[0]) 86 | np.random.shuffle(idxs) 87 | num_batch = int(np.ceil(idxs.shape[-1] / batch_size)) 88 | for epoch in range(epoch): 89 | idxs = np.arange(input.shape[0]) 90 | np.random.shuffle(idxs) 91 | for batch_num in range(num_batch): 92 | batch_idxs = idxs[batch_num * batch_size : (batch_num + 1) * batch_size] 93 | train_in = input[batch_idxs].float() 94 | train_targ = target[batch_idxs].float() 95 | optimizer.zero_grad() 96 | dec = self.forward(train_in) 97 | reconstruct_loss = ((train_targ-dec)**2).mean() 98 | ll = latent_loss(self.z_mean, self.z_sigma) 99 | loss = reconstruct_loss + beta*ll 100 | loss.backward() 101 | optimizer.step() 102 | val_input = input[idxs] 103 | val_target = target[idxs] 104 | val_dec = self.get_next_states(val_input) 105 | loss = ((val_target-val_dec)**2).mean().item() 106 | #print('vae loss', loss) 107 | return loss -------------------------------------------------------------------------------- /core/agent.py: -------------------------------------------------------------------------------- 1 | from utils.replay_memory import OnlineMemory 2 | from utils.torch import * 3 | import math 4 | import time 5 | 6 | def collect_samples(env, policy, custom_reward, 7 | mean_action, render, running_state, min_batch_size, update, trajectory, dtype=torch.float32): 8 | log = dict() 9 | memory = OnlineMemory() 10 | num_steps = 0 11 | total_reward = 0 12 | min_reward = 1e6 13 | max_reward = -1e6 14 | total_c_reward = 0 15 | min_c_reward = 1e6 16 | max_c_reward = -1e6 17 | num_episodes = 0 18 | if dtype == torch.float32: 19 | numpy_dtype = np.float32 20 | else: 21 | numpy_dtype = np.float64 22 | 23 | while num_steps < min_batch_size: 24 | state = env.reset() 25 | episode_steps = 0 26 | if running_state is not None: 27 | state = running_state(state, update=update) 28 | reward_episode = 0 29 | 30 | for t in range(10000): 31 | state_var = tensor(state).to(dtype).unsqueeze(0) 32 | with torch.no_grad(): 33 | if mean_action: 34 | action = policy(state_var)[0][0].numpy() 35 | else: 36 | action = policy.select_action(state_var)[0].numpy() 37 | action = int(action) if 
policy.is_disc_action else action.astype(numpy_dtype) 38 | next_state, reward, done, _ = env.step(action) 39 | episode_steps += 1 40 | reward_episode += reward 41 | 42 | if running_state is not None: 43 | next_state = running_state(next_state, update=update) 44 | #print(next_state, running_state.rs.n) 45 | 46 | if episode_steps == env._max_episode_steps: 47 | mask = 1. 48 | else: 49 | mask = float(not done) 50 | 51 | if custom_reward is not None: 52 | reward = custom_reward(state, action, next_state, done) 53 | total_c_reward += reward 54 | min_c_reward = min(min_c_reward, reward) 55 | max_c_reward = max(max_c_reward, reward) 56 | 57 | memory.push(state, action, reward, next_state, mask) 58 | if trajectory is not None: 59 | trajectory.push(state, action, reward, next_state, mask) 60 | 61 | if render: 62 | env.render() 63 | if done: 64 | if trajectory is not None: 65 | trajectory.clear() 66 | break 67 | 68 | state = next_state 69 | 70 | # log stats 71 | num_steps += (t + 1) 72 | num_episodes += 1 73 | total_reward += reward_episode 74 | min_reward = min(min_reward, reward_episode) 75 | max_reward = max(max_reward, reward_episode) 76 | 77 | log['num_steps'] = num_steps 78 | log['num_episodes'] = num_episodes 79 | log['total_reward'] = total_reward 80 | log['avg_reward'] = total_reward / num_episodes 81 | log['max_reward'] = max_reward 82 | log['min_reward'] = min_reward 83 | if custom_reward is not None: 84 | log['total_c_reward'] = total_c_reward 85 | log['avg_c_reward'] = total_c_reward / num_episodes 86 | log['max_c_reward'] = max_c_reward 87 | log['min_c_reward'] = min_c_reward 88 | 89 | return memory, log 90 | 91 | class Agent: 92 | 93 | def __init__(self, env, policy, device, custom_reward=None, 94 | mean_action=False, render=False, running_state=None, update=True, dtype=torch.float32): 95 | self.env = env 96 | self.policy = policy 97 | self.device = device 98 | self.custom_reward = custom_reward 99 | self.mean_action = mean_action 100 | self.running_state = running_state 101 | self.render = render 102 | self.update = update 103 | self.dtype = dtype 104 | 105 | def collect_samples(self, min_batch_size, trajectory=None): 106 | t_start = time.time() 107 | to_device(torch.device('cpu'), self.policy) 108 | memory, log = collect_samples(self.env, self.policy, self.custom_reward, self.mean_action, 109 | self.render, self.running_state, min_batch_size, self.update, trajectory, self.dtype) 110 | 111 | batch = memory.sample() 112 | to_device(self.device, self.policy) 113 | t_end = time.time() 114 | log['sample_time'] = t_end - t_start 115 | log['action_mean'] = np.mean(np.vstack(batch.action), axis=0) 116 | log['action_min'] = np.min(np.vstack(batch.action), axis=0) 117 | log['action_max'] = np.max(np.vstack(batch.action), axis=0) 118 | return batch, log 119 | -------------------------------------------------------------------------------- /models/sac_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Normal 5 | 6 | LOG_SIG_MAX = 2 7 | LOG_SIG_MIN = -20 8 | epsilon = 1e-6 9 | 10 | # Initialize Policy weights 11 | def weights_init_(m): 12 | if isinstance(m, nn.Linear): 13 | torch.nn.init.xavier_uniform_(m.weight, gain=1) 14 | torch.nn.init.constant_(m.bias, 0) 15 | 16 | 17 | class QNetwork(nn.Module): 18 | def __init__(self, num_inputs, num_actions, hidden_dim): 19 | super(QNetwork, self).__init__() 20 | 21 | # Q1 architecture 22 | 
self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim) 23 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 24 | self.linear3 = nn.Linear(hidden_dim, 1) 25 | 26 | # Q2 architecture 27 | self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim) 28 | self.linear5 = nn.Linear(hidden_dim, hidden_dim) 29 | self.linear6 = nn.Linear(hidden_dim, 1) 30 | 31 | self.apply(weights_init_) 32 | 33 | def forward(self, state, action): 34 | xu = torch.cat([state, action], 1) 35 | 36 | x1 = F.relu(self.linear1(xu)) 37 | x1 = F.relu(self.linear2(x1)) 38 | x1 = self.linear3(x1) 39 | 40 | x2 = F.relu(self.linear4(xu)) 41 | x2 = F.relu(self.linear5(x2)) 42 | x2 = self.linear6(x2) 43 | 44 | return x1, x2 45 | 46 | 47 | class GaussianPolicy(nn.Module): 48 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 49 | super(GaussianPolicy, self).__init__() 50 | 51 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 52 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 53 | 54 | self.mean_linear = nn.Linear(hidden_dim, num_actions) 55 | self.log_std_linear = nn.Linear(hidden_dim, num_actions) 56 | 57 | self.apply(weights_init_) 58 | 59 | # action rescaling 60 | if action_space is None: 61 | self.action_scale = torch.tensor(1.) 62 | self.action_bias = torch.tensor(0.) 63 | else: 64 | self.action_scale = torch.FloatTensor( 65 | (action_space.high - action_space.low) / 2.) 66 | self.action_bias = torch.FloatTensor( 67 | (action_space.high + action_space.low) / 2.) 68 | 69 | def forward(self, state): 70 | x = F.relu(self.linear1(state)) 71 | x = F.relu(self.linear2(x)) 72 | mean = self.mean_linear(x) 73 | log_std = self.log_std_linear(x) 74 | log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX) 75 | return mean, log_std 76 | 77 | def sample(self, state): 78 | mean, log_std = self.forward(state) 79 | std = log_std.exp() 80 | normal = Normal(mean, std) 81 | x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) 82 | y_t = torch.tanh(x_t) 83 | action = y_t * self.action_scale + self.action_bias 84 | log_prob = normal.log_prob(x_t) 85 | # Enforcing Action Bound 86 | log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon) 87 | log_prob = log_prob.sum(1, keepdim=True) 88 | return action, log_prob, torch.tanh(mean) 89 | 90 | def to(self, device): 91 | self.action_scale = self.action_scale.to(device) 92 | self.action_bias = self.action_bias.to(device) 93 | return super(GaussianPolicy, self).to(device) 94 | 95 | 96 | class DeterministicPolicy(nn.Module): 97 | def __init__(self, num_inputs, num_actions, hidden_dim, action_space=None): 98 | super(DeterministicPolicy, self).__init__() 99 | self.linear1 = nn.Linear(num_inputs, hidden_dim) 100 | self.linear2 = nn.Linear(hidden_dim, hidden_dim) 101 | 102 | self.mean = nn.Linear(hidden_dim, num_actions) 103 | self.noise = torch.Tensor(num_actions) 104 | 105 | self.apply(weights_init_) 106 | 107 | # action rescaling 108 | if action_space is None: 109 | self.action_scale = 1. 110 | self.action_bias = 0. 111 | else: 112 | self.action_scale = torch.FloatTensor( 113 | (action_space.high - action_space.low) / 2.) 114 | self.action_bias = torch.FloatTensor( 115 | (action_space.high + action_space.low) / 2.) 
116 | 117 | def forward(self, state): 118 | x = F.relu(self.linear1(state)) 119 | x = F.relu(self.linear2(x)) 120 | mean = torch.tanh(self.mean(x)) * self.action_scale + self.action_bias 121 | return mean 122 | 123 | def sample(self, state): 124 | mean = self.forward(state) 125 | noise = self.noise.normal_(0., std=0.1) 126 | noise = noise.clamp(-0.25, 0.25) 127 | action = mean + noise 128 | return action, torch.tensor(0.), mean 129 | 130 | def to(self, device): 131 | self.action_scale = self.action_scale.to(device) 132 | self.action_bias = self.action_bias.to(device) 133 | return super(DeterministicPolicy, self).to(device) 134 | -------------------------------------------------------------------------------- /envs/mujoco/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 2 | 3 |
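Returning to `models/sac_models.py` above, a minimal usage sketch of `GaussianPolicy.sample` (the observation/action sizes and the `Box` space here are made up for illustration, not taken from the repo):

```python
import gym
import numpy as np
import torch
from models.sac_models import GaussianPolicy

# Hypothetical sizes: an 11-dimensional observation and a 3-dimensional action in [-1, 1].
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
policy = GaussianPolicy(num_inputs=11, num_actions=3, hidden_dim=256,
                        action_space=action_space)

state = torch.randn(5, 11)                        # batch of 5 states
action, log_prob, mean_action = policy.sample(state)
print(action.shape, log_prob.shape)               # torch.Size([5, 3]) torch.Size([5, 1])
# log_prob already includes the tanh change-of-variables correction, so it is the
# log-density of the squashed, rescaled action rather than of the raw Gaussian sample.
```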