├── .gitignore ├── README.md ├── core ├── __init__.py ├── agent │ ├── base.py │ └── in_sample.py ├── environment │ ├── __init__.py │ ├── acrobot.py │ ├── ant.py │ ├── env_factory.py │ ├── halfcheetah.py │ ├── hopper.py │ ├── lunarlander.py │ ├── mountaincar.py │ └── walker2d.py ├── network │ ├── __init__.py │ ├── network_architectures.py │ ├── network_bodies.py │ ├── network_utils.py │ └── policy_factory.py └── utils │ ├── __init__.py │ ├── helpers.py │ ├── logger.py │ ├── run_funcs.py │ └── torch_utils.py ├── img └── after_fix.png └── run_ac_offline.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # Pycharm 121 | .idea 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Data 135 | data/ 136 | output/ 137 | plot/img 138 | 139 | # Cache 140 | *__pycache__* 141 | *.pyc 142 | 143 | #CMD 144 | cmd*.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a code release for our paper 'The In-Sample Softmax for Offline Reinforcement Learning' (https://openreview.net/pdf?id=u-RuvyDYqCM). 2 | 3 | # Running the code: 4 | 5 | ``` 6 | python run_ac_offline.py --seed 0 --env_name Ant --dataset expert --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 7 | 8 | python run_ac_offline.py --seed 0 --env_name Ant --dataset medexp --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 9 | 10 | python run_ac_offline.py --seed 0 --env_name Ant --dataset medium --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 11 | 12 | python run_ac_offline.py --seed 0 --env_name Ant --dataset medrep --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 13 | 14 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset expert --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 15 | 16 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset medexp --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.1 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 17 | 18 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset medium --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.33 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 19 | 20 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset medrep --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 21 | 22 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset expert --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 23 | 24 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset medexp --discrete_control 0 --state_dim 11 --action_dim 
3 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 25 | 26 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset medium --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.1 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 27 | 28 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset medrep --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 29 | 30 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset expert --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 31 | 32 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset medexp --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.1 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 33 | 34 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset medium --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.33 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 35 | 36 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset medrep --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000 37 | ``` 38 | 39 | **Update:** 40 | 41 | We fixed the policy network for continuous control (thanks to @typoverflow!). We reran the affected baselines with 5 runs each. The hyperparameters have been updated above, and the results are reported below. 42 | The fix **did not** change the **overall performance** or the **conclusions** reported in the paper. 43 | 44 | ![Results after the fix](img/after_fix.png) 45 | 46 | # D4RL installation 47 | If you are using *Ubuntu* and do not have *d4rl* installed yet, this section may help. 48 | 49 | 1. Download MuJoCo 50 | 51 | I am using mujoco210. It can be downloaded from https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz. Run the following commands from your home directory: 52 | ``` 53 | mkdir .mujoco 54 | mv mujoco210-linux-x86_64.tar.gz .mujoco 55 | cd .mujoco 56 | tar -xvzf mujoco210-linux-x86_64.tar.gz 57 | ``` 58 | 59 | Then, add the MuJoCo path: 60 | 61 | Open your ~/.bashrc file and add the following line: 62 | ``` 63 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin 64 | ``` 65 | 66 | Save the change and run the following command: 67 | ``` 68 | source ~/.bashrc 69 | ``` 70 | 71 | 2. Install other packages and D4RL 72 | ``` 73 | pip install mujoco_py 74 | pip install dm_control==1.0.7 75 | pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl 76 | ``` 77 | 78 | 3. 
Test the installation in python 79 | ``` 80 | import gym 81 | import d4rl 82 | env = gym.make('maze2d-umaze-v1') 83 | env.get_dataset() 84 | ``` 85 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/__init__.py -------------------------------------------------------------------------------- /core/agent/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pickle 5 | import torch 6 | import copy 7 | 8 | from core.utils import torch_utils 9 | 10 | 11 | class Replay: 12 | def __init__(self, memory_size, batch_size, seed=0): 13 | self.rng = np.random.RandomState(seed) 14 | self.memory_size = memory_size 15 | self.batch_size = batch_size 16 | self.data = [] 17 | self.pos = 0 18 | 19 | def feed(self, experience): 20 | if self.pos >= len(self.data): 21 | self.data.append(experience) 22 | else: 23 | self.data[self.pos] = experience 24 | self.pos = (self.pos + 1) % self.memory_size 25 | 26 | def feed_batch(self, experience): 27 | for exp in experience: 28 | self.feed(exp) 29 | 30 | def sample(self, batch_size=None): 31 | if batch_size is None: 32 | batch_size = self.batch_size 33 | sampled_indices = [self.rng.randint(0, len(self.data)) for _ in range(batch_size)] 34 | sampled_data = [self.data[ind] for ind in sampled_indices] 35 | batch_data = list(map(lambda x: np.asarray(x), zip(*sampled_data))) 36 | 37 | return batch_data 38 | 39 | def sample_array(self, batch_size=None): 40 | if batch_size is None: 41 | batch_size = self.batch_size 42 | 43 | sampled_indices = [self.rng.randint(0, len(self.data)) for _ in range(batch_size)] 44 | sampled_data = [self.data[ind] for ind in sampled_indices] 45 | 46 | return sampled_data 47 | 48 | def size(self): 49 | return len(self.data) 50 | 51 | def persist_memory(self, dir): 52 | for k in range(len(self.data)): 53 | transition = self.data[k] 54 | with open(os.path.join(dir, str(k)), "wb") as f: 55 | pickle.dump(transition, f) 56 | 57 | def clear(self): 58 | self.data = [] 59 | self.pos = 0 60 | 61 | def get_buffer(self): 62 | return self.data 63 | 64 | 65 | class Agent: 66 | def __init__(self, 67 | exp_path, 68 | seed, 69 | env_fn, 70 | timeout, 71 | gamma, 72 | offline_data, 73 | action_dim, 74 | batch_size, 75 | use_target_network, 76 | target_network_update_freq, 77 | evaluation_criteria, 78 | logger 79 | ): 80 | self.exp_path = exp_path 81 | self.seed = seed 82 | self.use_target_network = use_target_network 83 | self.target_network_update_freq = target_network_update_freq 84 | self.parameters_dir = self.get_parameters_dir() 85 | 86 | self.batch_size = batch_size 87 | self.env = env_fn() 88 | self.eval_env = copy.deepcopy(env_fn)() 89 | self.offline_data = offline_data 90 | self.replay = Replay(memory_size=2000000, batch_size=batch_size, seed=seed) 91 | self.state_normalizer = lambda x: x 92 | self.evaluation_criteria = evaluation_criteria 93 | self.logger = logger 94 | self.timeout = timeout 95 | self.action_dim = action_dim 96 | 97 | self.gamma = gamma 98 | self.device = 'cpu' 99 | self.stats_queue_size = 5 100 | self.episode_reward = 0 101 | self.episode_rewards = [] 102 | self.total_steps = 0 103 | self.reset = True 104 | self.ep_steps = 0 105 | self.num_episodes = 0 106 | self.ep_returns_queue_train = np.zeros(self.stats_queue_size) 107 | 
self.ep_returns_queue_test = np.zeros(self.stats_queue_size) 108 | self.train_stats_counter = 0 109 | self.test_stats_counter = 0 110 | self.agent_rng = np.random.RandomState(self.seed) 111 | 112 | self.populate_latest = False 113 | self.populate_states, self.populate_actions, self.populate_true_qs = None, None, None 114 | self.automatic_tmp_tuning = False 115 | 116 | self.state = None 117 | self.action = None 118 | self.next_state = None 119 | self.eps = 1e-8 120 | 121 | def get_parameters_dir(self): 122 | d = os.path.join(self.exp_path, "parameters") 123 | torch_utils.ensure_dir(d) 124 | return d 125 | 126 | def offline_param_init(self): 127 | self.trainset = self.training_set_construction(self.offline_data) 128 | self.training_size = len(self.trainset[0]) 129 | self.training_indexs = np.arange(self.training_size) 130 | 131 | self.training_loss = [] 132 | self.test_loss = [] 133 | self.tloss_increase = 0 134 | self.tloss_rec = np.inf 135 | 136 | def get_data(self): 137 | states, actions, rewards, next_states, terminals = self.replay.sample() 138 | in_ = torch_utils.tensor(self.state_normalizer(states), self.device) 139 | r = torch_utils.tensor(rewards, self.device) 140 | ns = torch_utils.tensor(self.state_normalizer(next_states), self.device) 141 | t = torch_utils.tensor(terminals, self.device) 142 | data = { 143 | 'obs': in_, 144 | 'act': actions, 145 | 'reward': r, 146 | 'obs2': ns, 147 | 'done': t 148 | } 149 | return data 150 | 151 | def fill_offline_data_to_buffer(self): 152 | self.trainset = self.training_set_construction(self.offline_data) 153 | train_s, train_a, train_r, train_ns, train_t = self.trainset 154 | for idx in range(len(train_s)): 155 | self.replay.feed([train_s[idx], train_a[idx], train_r[idx], train_ns[idx], train_t[idx]]) 156 | 157 | def step(self): 158 | # trans = self.feed_data() 159 | self.update_stats(0, None) 160 | data = self.get_data() 161 | losses = self.update(data) 162 | return losses 163 | 164 | def update(self, data): 165 | raise NotImplementedError 166 | 167 | def update_stats(self, reward, done): 168 | self.episode_reward += reward 169 | self.total_steps += 1 170 | self.ep_steps += 1 171 | if done or self.ep_steps == self.timeout: 172 | self.episode_rewards.append(self.episode_reward) 173 | self.num_episodes += 1 174 | if self.evaluation_criteria == "return": 175 | self.add_train_log(self.episode_reward) 176 | elif self.evaluation_criteria == "steps": 177 | self.add_train_log(self.ep_steps) 178 | else: 179 | raise NotImplementedError 180 | self.episode_reward = 0 181 | self.ep_steps = 0 182 | self.reset = True 183 | 184 | def add_train_log(self, ep_return): 185 | self.ep_returns_queue_train[self.train_stats_counter] = ep_return 186 | self.train_stats_counter += 1 187 | self.train_stats_counter = self.train_stats_counter % self.stats_queue_size 188 | 189 | def add_test_log(self, ep_return): 190 | self.ep_returns_queue_test[self.test_stats_counter] = ep_return 191 | self.test_stats_counter += 1 192 | self.test_stats_counter = self.test_stats_counter % self.stats_queue_size 193 | 194 | def populate_returns(self, log_traj=False, total_ep=None, initialize=False): 195 | total_ep = self.stats_queue_size if total_ep is None else total_ep 196 | total_steps = 0 197 | total_states = [] 198 | total_actions = [] 199 | total_returns = [] 200 | for ep in range(total_ep): 201 | ep_return, steps, traj = self.eval_episode(log_traj=log_traj) 202 | total_steps += steps 203 | total_states += traj[0] 204 | total_actions += traj[1] 205 | total_returns += traj[2] 206 | if 
self.evaluation_criteria == "return": 207 | self.add_test_log(ep_return) 208 | if initialize: 209 | self.add_train_log(ep_return) 210 | elif self.evaluation_criteria == "steps": 211 | self.add_test_log(steps) 212 | if initialize: 213 | self.add_train_log(steps) 214 | else: 215 | raise NotImplementedError 216 | return [total_states, total_actions, total_returns] 217 | 218 | def eval_episode(self, log_traj=False): 219 | ep_traj = [] 220 | state = self.eval_env.reset() 221 | total_rewards = 0 222 | ep_steps = 0 223 | done = False 224 | while True: 225 | action = self.eval_step(state) 226 | last_state = state 227 | state, reward, done, _ = self.eval_env.step([action]) 228 | # print(np.abs(state-last_state).sum(), "\n",action) 229 | if log_traj: 230 | ep_traj.append([last_state, action, reward]) 231 | total_rewards += reward 232 | ep_steps += 1 233 | if done or ep_steps == self.timeout: 234 | break 235 | 236 | states = [] 237 | actions = [] 238 | rets = [] 239 | if log_traj: 240 | ret = 0 241 | for i in range(len(ep_traj)-1, -1, -1): 242 | s, a, r = ep_traj[i] 243 | ret = r + self.gamma * ret 244 | rets.insert(0, ret) 245 | actions.insert(0, a) 246 | states.insert(0, s) 247 | return total_rewards, ep_steps, [states, actions, rets] 248 | 249 | def log_return(self, log_ary, name, elapsed_time): 250 | rewards = log_ary 251 | total_episodes = len(self.episode_rewards) 252 | mean, median, min_, max_ = np.mean(rewards), np.median(rewards), np.min(rewards), np.max(rewards) 253 | 254 | log_str = '%s LOG: steps %d, episodes %3d, ' \ 255 | 'returns %.2f/%.2f/%.2f/%.2f/%d (mean/median/min/max/num), %.2f steps/s' 256 | 257 | self.logger.info(log_str % (name, self.total_steps, total_episodes, mean, median, 258 | min_, max_, len(rewards), 259 | elapsed_time)) 260 | return mean, median, min_, max_ 261 | 262 | def log_file(self, elapsed_time=-1, test=True): 263 | mean, median, min_, max_ = self.log_return(self.ep_returns_queue_train, "TRAIN", elapsed_time) 264 | if test: 265 | self.populate_states, self.populate_actions, self.populate_true_qs = self.populate_returns(log_traj=True) 266 | self.populate_latest = True 267 | mean, median, min_, max_ = self.log_return(self.ep_returns_queue_test, "TEST", elapsed_time) 268 | try: 269 | normalized = np.array([self.eval_env.env.unwrapped.get_normalized_score(ret_) for ret_ in self.ep_returns_queue_test]) 270 | mean, median, min_, max_ = self.log_return(normalized, "Normalized", elapsed_time) 271 | except: 272 | pass 273 | return mean, median, min_, max_ 274 | 275 | def policy(self, o, eval=False): 276 | o = torch_utils.tensor(self.state_normalizer(o), self.device) 277 | with torch.no_grad(): 278 | a, _ = self.ac.pi(o, deterministic=eval) 279 | a = torch_utils.to_np(a) 280 | return a 281 | 282 | def eval_step(self, state): 283 | a = self.policy(state, eval=True) 284 | return a 285 | 286 | def training_set_construction(self, data_dict): 287 | assert len(list(data_dict.keys())) == 1 288 | data_dict = data_dict[list(data_dict.keys())[0]] 289 | states = data_dict['states'] 290 | actions = data_dict['actions'] 291 | rewards = data_dict['rewards'] 292 | next_states = data_dict['next_states'] 293 | terminations = data_dict['terminations'] 294 | return [states, actions, rewards, next_states, terminations] 295 | -------------------------------------------------------------------------------- /core/agent/in_sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from core.agent import base 3 | from collections import 
namedtuple 4 | import os 5 | import torch 6 | 7 | from core.network.policy_factory import MLPCont, MLPDiscrete 8 | from core.network.network_architectures import DoubleCriticNetwork, DoubleCriticDiscrete, FCNetwork 9 | 10 | class InSampleAC(base.Agent): 11 | def __init__(self, 12 | device, 13 | discrete_control, 14 | state_dim, 15 | action_dim, 16 | hidden_units, 17 | learning_rate, 18 | tau, 19 | polyak, 20 | exp_path, 21 | seed, 22 | env_fn, 23 | timeout, 24 | gamma, 25 | offline_data, 26 | batch_size, 27 | use_target_network, 28 | target_network_update_freq, 29 | evaluation_criteria, 30 | logger 31 | ): 32 | super(InSampleAC, self).__init__( 33 | exp_path=exp_path, 34 | seed=seed, 35 | env_fn=env_fn, 36 | timeout=timeout, 37 | gamma=gamma, 38 | offline_data=offline_data, 39 | action_dim=action_dim, 40 | batch_size=batch_size, 41 | use_target_network=use_target_network, 42 | target_network_update_freq=target_network_update_freq, 43 | evaluation_criteria=evaluation_criteria, 44 | logger=logger 45 | ) 46 | 47 | def get_policy_func(): 48 | if discrete_control: 49 | pi = MLPDiscrete(device, state_dim, action_dim, [hidden_units]*2) 50 | else: 51 | pi = MLPCont(device, state_dim, action_dim, [hidden_units]*2) 52 | return pi 53 | 54 | def get_critic_func(): 55 | if discrete_control: 56 | q1q2 = DoubleCriticDiscrete(device, state_dim, [hidden_units]*2, action_dim) 57 | else: 58 | q1q2 = DoubleCriticNetwork(device, state_dim, action_dim, [hidden_units]*2) 59 | return q1q2 60 | 61 | pi = get_policy_func() 62 | q1q2 = get_critic_func() 63 | AC = namedtuple('AC', ['q1q2', 'pi']) 64 | self.ac = AC(q1q2=q1q2, pi=pi) 65 | pi_target = get_policy_func() 66 | q1q2_target = get_critic_func() 67 | q1q2_target.load_state_dict(q1q2.state_dict()) 68 | pi_target.load_state_dict(pi.state_dict()) 69 | ACTarg = namedtuple('ACTarg', ['q1q2', 'pi']) 70 | self.ac_targ = ACTarg(q1q2=q1q2_target, pi=pi_target) 71 | self.ac_targ.q1q2.load_state_dict(self.ac.q1q2.state_dict()) 72 | self.ac_targ.pi.load_state_dict(self.ac.pi.state_dict()) 73 | self.beh_pi = get_policy_func() 74 | self.value_net = FCNetwork(device, np.prod(state_dim), [hidden_units]*2, 1) 75 | 76 | self.pi_optimizer = torch.optim.Adam(list(self.ac.pi.parameters()), learning_rate) 77 | self.q_optimizer = torch.optim.Adam(list(self.ac.q1q2.parameters()), learning_rate) 78 | self.value_optimizer = torch.optim.Adam(list(self.value_net.parameters()), learning_rate) 79 | self.beh_pi_optimizer = torch.optim.Adam(list(self.beh_pi.parameters()), learning_rate) 80 | self.exp_threshold = 10000 81 | if discrete_control: 82 | self.get_q_value = self.get_q_value_discrete 83 | self.get_q_value_target = self.get_q_value_target_discrete 84 | else: 85 | self.get_q_value = self.get_q_value_cont 86 | self.get_q_value_target = self.get_q_value_target_cont 87 | 88 | self.tau = tau 89 | self.polyak = polyak 90 | self.fill_offline_data_to_buffer() 91 | self.offline_param_init() 92 | return 93 | 94 | 95 | def compute_loss_beh_pi(self, data): 96 | """L_{\omega}, learn behavior policy""" 97 | states, actions = data['obs'], data['act'] 98 | beh_log_probs = self.beh_pi.get_logprob(states, actions) 99 | beh_loss = -beh_log_probs.mean() 100 | return beh_loss, beh_log_probs 101 | 102 | def compute_loss_value(self, data): 103 | """L_{\phi}, learn z for state value, v = tau log z""" 104 | states = data['obs'] 105 | v_phi = self.value_net(states).squeeze(-1) 106 | with torch.no_grad(): 107 | actions, log_probs = self.ac.pi(states) 108 | min_Q, _, _ = self.get_q_value_target(states, actions) 
109 | target = min_Q - self.tau * log_probs 110 | value_loss = (0.5 * (v_phi - target) ** 2).mean() 111 | return value_loss, v_phi.detach().numpy(), log_probs.detach().numpy() 112 | 113 | def get_state_value(self, state): 114 | with torch.no_grad(): 115 | value = self.value_net(state).squeeze(-1) 116 | return value 117 | 118 | def compute_loss_q(self, data): 119 | states, actions, rewards, next_states, dones = data['obs'], data['act'], data['reward'], data['obs2'], data['done'] 120 | with torch.no_grad(): 121 | next_actions, log_probs = self.ac.pi(next_states) 122 | min_Q, _, _ = self.get_q_value_target(next_states, next_actions) 123 | q_target = rewards + self.gamma * (1 - dones) * (min_Q - self.tau * log_probs) 124 | 125 | minq, q1, q2 = self.get_q_value(states, actions, with_grad=True) 126 | 127 | critic1_loss = (0.5 * (q_target - q1) ** 2).mean() 128 | critic2_loss = (0.5 * (q_target - q2) ** 2).mean() 129 | loss_q = (critic1_loss + critic2_loss) * 0.5 130 | q_info = minq.detach().numpy() 131 | return loss_q, q_info 132 | 133 | def compute_loss_pi(self, data): 134 | """L_{\psi}, extract learned policy""" 135 | states, actions = data['obs'], data['act'] 136 | 137 | log_probs = self.ac.pi.get_logprob(states, actions) 138 | min_Q, _, _ = self.get_q_value(states, actions, with_grad=False) 139 | with torch.no_grad(): 140 | value = self.get_state_value(states) 141 | beh_log_prob = self.beh_pi.get_logprob(states, actions) 142 | 143 | clipped = torch.clip(torch.exp((min_Q - value) / self.tau - beh_log_prob), self.eps, self.exp_threshold) 144 | pi_loss = -(clipped * log_probs).mean() 145 | return pi_loss, "" 146 | 147 | def update_beta(self, data): 148 | loss_beh_pi, _ = self.compute_loss_beh_pi(data) 149 | self.beh_pi_optimizer.zero_grad() 150 | loss_beh_pi.backward() 151 | self.beh_pi_optimizer.step() 152 | return loss_beh_pi 153 | 154 | def update(self, data): 155 | loss_beta = self.update_beta(data).item() 156 | 157 | self.value_optimizer.zero_grad() 158 | loss_vs, v_info, logp_info = self.compute_loss_value(data) 159 | loss_vs.backward() 160 | self.value_optimizer.step() 161 | 162 | loss_q, qinfo = self.compute_loss_q(data) 163 | self.q_optimizer.zero_grad() 164 | loss_q.backward() 165 | self.q_optimizer.step() 166 | 167 | loss_pi, _ = self.compute_loss_pi(data) 168 | self.pi_optimizer.zero_grad() 169 | loss_pi.backward() 170 | self.pi_optimizer.step() 171 | 172 | if self.use_target_network and self.total_steps % self.target_network_update_freq == 0: 173 | self.sync_target() 174 | 175 | return {"beta": loss_beta, 176 | "actor": loss_pi.item(), 177 | "critic": loss_q.item(), 178 | "value": loss_vs.item(), 179 | "q_info": qinfo.mean(), 180 | "v_info": v_info.mean(), 181 | "logp_info": logp_info.mean(), 182 | } 183 | 184 | 185 | def get_q_value_discrete(self, o, a, with_grad=False): 186 | if with_grad: 187 | q1_pi, q2_pi = self.ac.q1q2(o) 188 | q1_pi, q2_pi = q1_pi[np.arange(len(a)), a], q2_pi[np.arange(len(a)), a] 189 | q_pi = torch.min(q1_pi, q2_pi) 190 | else: 191 | with torch.no_grad(): 192 | q1_pi, q2_pi = self.ac.q1q2(o) 193 | q1_pi, q2_pi = q1_pi[np.arange(len(a)), a], q2_pi[np.arange(len(a)), a] 194 | q_pi = torch.min(q1_pi, q2_pi) 195 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1) 196 | 197 | def get_q_value_target_discrete(self, o, a): 198 | with torch.no_grad(): 199 | q1_pi, q2_pi = self.ac_targ.q1q2(o) 200 | q1_pi, q2_pi = q1_pi[np.arange(len(a)), a], q2_pi[np.arange(len(a)), a] 201 | q_pi = torch.min(q1_pi, q2_pi) 202 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), 
q2_pi.squeeze(-1) 203 | 204 | def get_q_value_cont(self, o, a, with_grad=False): 205 | if with_grad: 206 | q1_pi, q2_pi = self.ac.q1q2(o, a) 207 | q_pi = torch.min(q1_pi, q2_pi) 208 | else: 209 | with torch.no_grad(): 210 | q1_pi, q2_pi = self.ac.q1q2(o, a) 211 | q_pi = torch.min(q1_pi, q2_pi) 212 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1) 213 | 214 | def get_q_value_target_cont(self, o, a): 215 | with torch.no_grad(): 216 | q1_pi, q2_pi = self.ac_targ.q1q2(o, a) 217 | q_pi = torch.min(q1_pi, q2_pi) 218 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1) 219 | 220 | def sync_target(self): 221 | with torch.no_grad(): 222 | for p, p_targ in zip(self.ac.q1q2.parameters(), self.ac_targ.q1q2.parameters()): 223 | p_targ.data.mul_(self.polyak) 224 | p_targ.data.add_((1 - self.polyak) * p.data) 225 | for p, p_targ in zip(self.ac.pi.parameters(), self.ac_targ.pi.parameters()): 226 | p_targ.data.mul_(self.polyak) 227 | p_targ.data.add_((1 - self.polyak) * p.data) 228 | 229 | def save(self): 230 | parameters_dir = self.parameters_dir 231 | path = os.path.join(parameters_dir, "actor_net") 232 | torch.save(self.ac.pi.state_dict(), path) 233 | 234 | path = os.path.join(parameters_dir, "critic_net") 235 | torch.save(self.ac.q1q2.state_dict(), path) 236 | 237 | path = os.path.join(parameters_dir, "vs_net") 238 | torch.save(self.value_net.state_dict(), path) 239 | 240 | 241 | 242 | -------------------------------------------------------------------------------- /core/environment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/environment/__init__.py -------------------------------------------------------------------------------- /core/environment/acrobot.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import gym 5 | import copy 6 | 7 | import core.utils.helpers 8 | from core.utils.torch_utils import random_seed 9 | 10 | 11 | class Acrobot: 12 | def __init__(self, seed=np.random.randint(int(1e5))): 13 | random_seed(seed) 14 | self.state_dim = (6,) 15 | self.action_dim = 3 16 | self.env = gym.make('Acrobot-v1') 17 | self.env._seed = seed 18 | self.env._max_episode_steps = np.inf # control timeout setting in agent 19 | self.state = None 20 | 21 | def generate_state(self, coords): 22 | return coords 23 | 24 | def reset(self): 25 | self.state = np.asarray(self.env.reset()) 26 | return self.state 27 | 28 | def step(self, a): 29 | state, reward, done, info = self.env.step(a[0]) 30 | self.state = state 31 | # self.env.render() 32 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 33 | 34 | def get_visualization_segment(self): 35 | raise NotImplementedError 36 | 37 | def get_useful(self, state=None): 38 | if state: 39 | return state 40 | else: 41 | return np.array(self.env.state) 42 | 43 | def info(self, key): 44 | return 45 | 46 | -------------------------------------------------------------------------------- /core/environment/ant.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1' 3 | 4 | import gym 5 | import d4rl 6 | import numpy as np 7 | 8 | from core.utils.torch_utils import random_seed 9 | 10 | 11 | class Ant: 12 | def __init__(self, seed=np.random.randint(int(1e5))): 13 | random_seed(seed) 14 | self.state_dim = (111,) 15 | 
self.action_dim = 8 16 | # self.env = gym.make('Ant-v2') 17 | self.env = gym.make('ant-random-v2')# Loading d4rl env. For the convinience of getting normalized score from d4rl 18 | self.env.unwrapped.seed(seed) 19 | self.env._max_episode_steps = np.inf # control timeout setting in agent 20 | self.state = None 21 | 22 | def reset(self): 23 | return self.env.reset() 24 | 25 | def step(self, a): 26 | ret = self.env.step(a[0]) 27 | state, reward, done, info = ret 28 | self.state = state 29 | # self.env.render() 30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 31 | 32 | def get_visualization_segment(self): 33 | raise NotImplementedError 34 | 35 | def get_useful(self, state=None): 36 | if state: 37 | return state 38 | else: 39 | return np.array(self.env.state) 40 | 41 | def info(self, key): 42 | return 43 | -------------------------------------------------------------------------------- /core/environment/env_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from core.environment.mountaincar import MountainCar 4 | from core.environment.acrobot import Acrobot 5 | from core.environment.lunarlander import LunarLander 6 | from core.environment.halfcheetah import HalfCheetah 7 | from core.environment.walker2d import Walker2d 8 | from core.environment.hopper import Hopper 9 | from core.environment.ant import Ant 10 | 11 | class EnvFactory: 12 | @classmethod 13 | def create_env_fn(cls, cfg): 14 | if cfg.env_name == 'MountainCar': 15 | return lambda: MountainCar(cfg.seed) 16 | elif cfg.env_name == 'Acrobot': 17 | return lambda: Acrobot(cfg.seed) 18 | elif cfg.env_name == 'LunarLander': 19 | return lambda: LunarLander(cfg.seed) 20 | elif cfg.env_name == 'HalfCheetah': 21 | return lambda: HalfCheetah(cfg.seed) 22 | elif cfg.env_name == 'Walker2d': 23 | return lambda: Walker2d(cfg.seed) 24 | elif cfg.env_name == 'Hopper': 25 | return lambda: Hopper(cfg.seed) 26 | elif cfg.env_name == 'Ant': 27 | return lambda: Ant(cfg.seed) 28 | else: 29 | print(cfg.env_name) 30 | raise NotImplementedError -------------------------------------------------------------------------------- /core/environment/halfcheetah.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1' 3 | 4 | import gym 5 | import d4rl 6 | import numpy as np 7 | 8 | from core.utils.torch_utils import random_seed 9 | 10 | 11 | class HalfCheetah: 12 | def __init__(self, seed=np.random.randint(int(1e5))): 13 | random_seed(seed) 14 | self.state_dim = (17,) 15 | self.action_dim = 6 16 | # self.env = gym.make('HalfCheetah-v2') 17 | self.env = gym.make('halfcheetah-random-v2') # Loading d4rl env. 
For the convinience of getting normalized score from d4rl 18 | self.env.unwrapped.seed(seed) 19 | self.env._max_episode_steps = np.inf # control timeout setting in agent 20 | self.state = None 21 | 22 | def reset(self): 23 | return self.env.reset() 24 | 25 | def step(self, a): 26 | ret = self.env.step(a[0]) 27 | state, reward, done, info = ret 28 | self.state = state 29 | # self.env.render() 30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 31 | 32 | def get_visualization_segment(self): 33 | raise NotImplementedError 34 | 35 | def get_useful(self, state=None): 36 | if state: 37 | return state 38 | else: 39 | return np.array(self.env.state) 40 | 41 | def info(self, key): 42 | return 43 | -------------------------------------------------------------------------------- /core/environment/hopper.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1' 3 | 4 | import gym 5 | import d4rl 6 | import numpy as np 7 | 8 | from core.utils.torch_utils import random_seed 9 | 10 | 11 | class Hopper: 12 | def __init__(self, seed=np.random.randint(int(1e5))): 13 | random_seed(seed) 14 | self.state_dim = (11,) 15 | self.action_dim = 3 16 | # self.env = gym.make('Hopper-v2') 17 | self.env = gym.make('hopper-random-v2') # Loading d4rl env. For the convinience of getting normalized score from d4rl 18 | self.env.unwrapped.seed(seed) 19 | self.env._max_episode_steps = np.inf # control timeout setting in agent 20 | self.state = None 21 | 22 | def reset(self): 23 | return self.env.reset() 24 | 25 | def step(self, a): 26 | ret = self.env.step(a[0]) 27 | state, reward, done, info = ret 28 | self.state = state 29 | # self.env.env.render() 30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 31 | 32 | def get_visualization_segment(self): 33 | raise NotImplementedError 34 | 35 | def get_useful(self, state=None): 36 | if state: 37 | return state 38 | else: 39 | return np.array(self.env.state) 40 | 41 | def info(self, key): 42 | return 43 | -------------------------------------------------------------------------------- /core/environment/lunarlander.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import copy 4 | 5 | from core.utils.torch_utils import random_seed 6 | 7 | 8 | class LunarLander: 9 | def __init__(self, seed=np.random.randint(int(1e5))): 10 | random_seed(seed) 11 | self.state_dim = (8,) 12 | self.action_dim = 4 13 | self.env = gym.make('LunarLander-v2') 14 | self.env._seed = seed 15 | self.env._max_episode_steps = np.inf # control timeout setting in agent 16 | 17 | def generate_state(self, coords): 18 | return coords 19 | 20 | def reset(self): 21 | return np.asarray(self.env.reset()) 22 | 23 | def step(self, a): 24 | state, reward, done, info = self.env.step(a[0]) 25 | # self.env.render() 26 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 27 | 28 | def get_visualization_segment(self): 29 | raise NotImplementedError 30 | 31 | def get_useful(self, state=None): 32 | if state: 33 | return state 34 | else: 35 | return np.array(self.env.state) 36 | 37 | def info(self, key): 38 | return 39 | -------------------------------------------------------------------------------- /core/environment/mountaincar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | import copy 4 | 5 | from core.utils.torch_utils import 
random_seed 6 | 7 | 8 | class MountainCar: 9 | def __init__(self, seed=np.random.randint(int(1e5))): 10 | random_seed(seed) 11 | self.state_dim = (2,) 12 | self.action_dim = 3 13 | self.env = gym.make('MountainCar-v0') 14 | self.env._seed = seed 15 | self.env._max_episode_steps = np.inf # control timeout setting in agent 16 | 17 | def generate_state(self, coords): 18 | return coords 19 | 20 | def reset(self): 21 | return np.asarray(self.env.reset()) 22 | 23 | def step(self, a): 24 | state, reward, done, info = self.env.step(a[0]) 25 | # self.env.render() 26 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 27 | 28 | def get_visualization_segment(self): 29 | raise NotImplementedError 30 | 31 | def get_useful(self, state=None): 32 | if state: 33 | return state 34 | else: 35 | return np.array(self.env.state) 36 | 37 | def info(self, key): 38 | return 39 | 40 | -------------------------------------------------------------------------------- /core/environment/walker2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1' 3 | 4 | import gym 5 | import d4rl 6 | import numpy as np 7 | 8 | from core.utils.torch_utils import random_seed 9 | 10 | 11 | class Walker2d: 12 | def __init__(self, seed=np.random.randint(int(1e5))): 13 | random_seed(seed) 14 | self.state_dim = (17,) 15 | self.action_dim = 6 16 | # self.env = gym.make('Walker2d-v2') 17 | self.env = gym.make('walker2d-random-v2')# Loading d4rl env. For the convinience of getting normalized score from d4rl 18 | self.env.unwrapped.seed(seed) 19 | self.env._max_episode_steps = np.inf # control timeout setting in agent 20 | self.state = None 21 | 22 | def reset(self): 23 | return self.env.reset() 24 | 25 | def step(self, a): 26 | ret = self.env.step(a[0]) 27 | state, reward, done, info = ret 28 | self.state = state 29 | # self.env.render() 30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info 31 | 32 | def get_visualization_segment(self): 33 | raise NotImplementedError 34 | 35 | def get_useful(self, state=None): 36 | if state: 37 | return state 38 | else: 39 | return np.array(self.env.state) 40 | 41 | def info(self, key): 42 | return 43 | -------------------------------------------------------------------------------- /core/network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/network/__init__.py -------------------------------------------------------------------------------- /core/network/network_architectures.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | 6 | from core.network import network_utils, network_bodies 7 | from core.utils import torch_utils 8 | 9 | 10 | class FCNetwork(nn.Module): 11 | def __init__(self, device, input_units, hidden_units, output_units, head_activation=lambda x:x): 12 | super().__init__() 13 | body = network_bodies.FCBody(device, input_units, hidden_units=tuple(hidden_units), init_type='xavier') 14 | self.body = body 15 | self.fc_head = network_utils.layer_init_xavier(nn.Linear(body.feature_dim, output_units, bias=True), bias=True) 16 | self.device = device 17 | self.head_activation = head_activation 18 | self.to(device) 19 | 20 | def forward(self, x): 21 | if not isinstance(x, 
torch.Tensor): x = torch_utils.tensor(x, self.device) 22 | if len(x.shape) > 2: x = x.view(x.shape[0], -1) 23 | y = self.body(x) 24 | y = self.fc_head(y) 25 | y = self.head_activation(y) 26 | return y 27 | 28 | class DoubleCriticDiscrete(nn.Module): 29 | def __init__(self, device, input_units, hidden_units, output_units): 30 | super().__init__() 31 | self.device = device 32 | self.q1_net = FCNetwork(device, input_units, hidden_units, output_units) 33 | self.q2_net = FCNetwork(device, input_units, hidden_units, output_units) 34 | 35 | # def forward(self, x, a): 36 | def forward(self, x): 37 | if not isinstance(x, torch.Tensor): x = torch_utils.tensor(x, self.device) 38 | recover_size = False 39 | if len(x.size()) == 1: 40 | recover_size = True 41 | x = x.reshape((1, -1)) 42 | q1 = self.q1_net(x) 43 | q2 = self.q2_net(x) 44 | if recover_size: 45 | q1 = q1[0] 46 | q2 = q2[0] 47 | return q1, q2 48 | 49 | 50 | class DoubleCriticNetwork(nn.Module): 51 | def __init__(self, device, num_inputs, num_actions, hidden_units): 52 | super(DoubleCriticNetwork, self).__init__() 53 | self.device = device 54 | 55 | # Q1 architecture 56 | self.body1 = network_bodies.FCBody(device, num_inputs + num_actions, hidden_units=tuple(hidden_units)) 57 | self.head1 = network_utils.layer_init_xavier(nn.Linear(self.body1.feature_dim, 1)) 58 | # Q2 architecture 59 | self.body2 = network_bodies.FCBody(device, num_inputs + num_actions, hidden_units=tuple(hidden_units)) 60 | self.head2 = network_utils.layer_init_xavier(nn.Linear(self.body2.feature_dim, 1)) 61 | 62 | def forward(self, state, action): 63 | if not isinstance(state, torch.Tensor): state = torch_utils.tensor(state, self.device) 64 | recover_size = False 65 | if len(state.shape) > 2: 66 | state = state.view(state.shape[0], -1) 67 | action = action.view(action.shape[0], -1) 68 | elif len(state.shape) == 1: 69 | state = state.view(1, -1) 70 | action = action.view(1, -1) 71 | recover_size = True 72 | if not isinstance(action, torch.Tensor): action = torch_utils.tensor(action, self.device) 73 | 74 | xu = torch.cat([state, action], 1) 75 | 76 | q1 = self.head1(self.body1(xu)) 77 | q2 = self.head2(self.body2(xu)) 78 | 79 | if recover_size: 80 | q1 = q1[0] 81 | q2 = q2[0] 82 | return q1, q2 83 | 84 | -------------------------------------------------------------------------------- /core/network/network_bodies.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as functional 6 | 7 | from core.network import network_utils 8 | 9 | class FCBody(nn.Module): 10 | def __init__(self, device, input_dim, hidden_units=(64, 64), activation=functional.relu, init_type='xavier', info=None): 11 | super().__init__() 12 | self.to(device) 13 | self.device = device 14 | dims = (input_dim,) + hidden_units 15 | self.layers = nn.ModuleList([network_utils.layer_init_xavier(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 16 | 17 | if init_type == "xavier": 18 | self.layers = nn.ModuleList([network_utils.layer_init_xavier(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 19 | elif init_type == "uniform": 20 | self.layers = nn.ModuleList([network_utils.layer_init_uniform(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 21 | elif init_type == "zeros": 22 | self.layers = 
nn.ModuleList([network_utils.layer_init_zero(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 23 | elif init_type == "constant": 24 | self.layers = nn.ModuleList([network_utils.layer_init_constant(nn.Linear(dim_in, dim_out).to(device), const=info) for dim_in, dim_out in zip(dims[:-1], dims[1:])]) 25 | else: 26 | raise ValueError('init_type is not defined: {}'.format(init_type)) 27 | 28 | self.activation = activation 29 | self.feature_dim = dims[-1] 30 | 31 | def forward(self, x): 32 | for layer in self.layers: 33 | x = self.activation(layer(x)) 34 | return x 35 | 36 | def compute_lipschitz_upper(self): 37 | return [np.linalg.norm(layer.weight.detach().cpu().numpy(), ord=2) for layer in self.layers] 38 | 39 | 40 | class ConvBody(nn.Module): 41 | def __init__(self, device, state_dim, architecture): 42 | super().__init__() 43 | 44 | def size(size, kernel_size=3, stride=1, padding=0): 45 | return (size + 2 * padding - (kernel_size - 1) - 1) // stride + 1 46 | 47 | spatial_length, _, in_channels = state_dim 48 | num_units = None 49 | layers = nn.ModuleList() 50 | for layer_cfg in architecture['conv_layers']: 51 | layers.append(nn.Conv2d(layer_cfg["in"], layer_cfg["out"], layer_cfg["kernel"], 52 | layer_cfg["stride"], layer_cfg["pad"])) 53 | if not num_units: 54 | num_units = size(spatial_length, layer_cfg["kernel"], layer_cfg["stride"], layer_cfg["pad"]) 55 | else: 56 | num_units = size(num_units, layer_cfg["kernel"], layer_cfg["stride"], layer_cfg["pad"]) 57 | num_units = num_units ** 2 * architecture["conv_layers"][-1]["out"] 58 | 59 | self.feature_dim = num_units 60 | self.spatial_length = spatial_length 61 | self.in_channels = in_channels 62 | self.layers = layers 63 | self.to(device) 64 | self.device = device 65 | 66 | def forward(self, x): 67 | x = functional.relu(self.layers[0](self.shape_image(x))) 68 | for idx, layer in enumerate(self.layers[1:]): 69 | x = functional.relu(layer(x)) 70 | # return x.view(x.size(0), -1) 71 | return x.reshape(x.size(0), -1) 72 | -------------------------------------------------------------------------------- /core/network/network_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def layer_init(layer, w_scale=1.0): 5 | nn.init.orthogonal_(layer.weight.data) 6 | layer.weight.data.mul_(w_scale) 7 | nn.init.constant_(layer.bias.data, 0) 8 | return layer 9 | 10 | 11 | def layer_init_zero(layer, bias=True): 12 | nn.init.constant_(layer.weight, 0) 13 | if bias: 14 | nn.init.constant_(layer.bias.data, 0) 15 | return layer 16 | 17 | def layer_init_constant(layer, const, bias=True): 18 | nn.init.constant_(layer.weight, const) 19 | if bias: 20 | nn.init.constant_(layer.bias.data, const) 21 | return layer 22 | 23 | 24 | def layer_init_xavier(layer, bias=True): 25 | nn.init.xavier_uniform_(layer.weight) 26 | if bias: 27 | nn.init.constant_(layer.bias.data, 0) 28 | return layer 29 | 30 | def layer_init_uniform(layer, low=-0.003, high=0.003, bias=0): 31 | nn.init.uniform_(layer.weight, low, high) 32 | if not (type(bias)==bool and bias==False): 33 | nn.init.constant_(layer.bias.data, bias) 34 | return layer 35 | -------------------------------------------------------------------------------- /core/network/policy_factory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.distributions import Normal 6 | from 
torch.distributions import Categorical 7 | 8 | from core.network import network_utils, network_bodies 9 | from core.utils import torch_utils 10 | 11 | 12 | class MLPCont(nn.Module): 13 | def __init__(self, device, obs_dim, act_dim, hidden_sizes, action_range=1.0, init_type='xavier'): 14 | super().__init__() 15 | self.device = device 16 | body = network_bodies.FCBody(device, obs_dim, hidden_units=tuple(hidden_sizes), init_type=init_type) 17 | body_out = obs_dim if hidden_sizes==[] else hidden_sizes[-1] 18 | self.body = body 19 | self.mu_layer = network_utils.layer_init_xavier(nn.Linear(body_out, act_dim)) 20 | self.log_std_logits = nn.Parameter(torch.zeros(act_dim, requires_grad=True)) 21 | self.min_log_std = -6 22 | self.max_log_std = 0 23 | self.action_range = action_range 24 | 25 | """https://github.com/hari-sikchi/AWAC/blob/3ad931ec73101798ffe82c62b19313a8607e4f1e/core.py#L91""" 26 | def forward(self, obs, deterministic=False): 27 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device) 28 | recover_size = False 29 | if len(obs.size()) == 1: 30 | recover_size = True 31 | obs = obs.reshape((1, -1)) 32 | net_out = self.body(obs) 33 | mu = self.mu_layer(net_out) 34 | mu = torch.tanh(mu) * self.action_range 35 | 36 | log_std = torch.sigmoid(self.log_std_logits) 37 | log_std = self.min_log_std + log_std * (self.max_log_std - self.min_log_std) 38 | std = torch.exp(log_std) 39 | pi_distribution = Normal(mu, std) 40 | if deterministic: 41 | pi_action = mu 42 | else: 43 | pi_action = pi_distribution.rsample() 44 | logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1) 45 | 46 | if recover_size: 47 | pi_action, logp_pi = pi_action[0], logp_pi[0] 48 | return pi_action, logp_pi 49 | 50 | def get_logprob(self, obs, actions): 51 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device) 52 | if not isinstance(actions, torch.Tensor): actions = torch_utils.tensor(actions, self.device) 53 | net_out = self.body(obs) 54 | mu = self.mu_layer(net_out) 55 | mu = torch.tanh(mu) * self.action_range 56 | log_std = torch.sigmoid(self.log_std_logits) 57 | # log_std = self.log_std_layer(net_out) 58 | log_std = self.min_log_std + log_std * ( 59 | self.max_log_std - self.min_log_std) 60 | std = torch.exp(log_std) 61 | pi_distribution = Normal(mu, std) 62 | logp_pi = pi_distribution.log_prob(actions).sum(axis=-1) 63 | return logp_pi 64 | 65 | 66 | class MLPDiscrete(nn.Module): 67 | def __init__(self, device, obs_dim, act_dim, hidden_sizes, init_type='xavier'): 68 | super().__init__() 69 | self.device = device 70 | body = network_bodies.FCBody(device, obs_dim, hidden_units=tuple(hidden_sizes), init_type=init_type) 71 | body_out = obs_dim if hidden_sizes==[] else hidden_sizes[-1] 72 | self.body = body 73 | self.mu_layer = network_utils.layer_init_xavier(nn.Linear(body_out, act_dim)) 74 | self.log_std_logits = nn.Parameter(torch.zeros(act_dim, requires_grad=True)) 75 | self.min_log_std = -6 76 | self.max_log_std = 0 77 | 78 | def forward(self, obs, deterministic=True): 79 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device) 80 | recover_size = False 81 | if len(obs.size()) == 1: 82 | recover_size = True 83 | obs = obs.reshape((1, -1)) 84 | net_out = self.body(obs) 85 | probs = self.mu_layer(net_out) 86 | probs = F.softmax(probs, dim=1) 87 | m = Categorical(probs) 88 | action = m.sample() 89 | logp = m.log_prob(action) 90 | if recover_size: 91 | action, logp = action[0], logp[0] 92 | return action, logp 93 | 94 | def get_logprob(self, 
obs, actions): 95 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device) 96 | if not isinstance(actions, torch.Tensor): actions = torch_utils.tensor(actions, self.device) 97 | net_out = self.body(obs) 98 | probs = self.mu_layer(net_out) 99 | probs = F.softmax(probs, dim=1) 100 | m = Categorical(probs) 101 | logp_pi = m.log_prob(actions) 102 | return logp_pi 103 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/utils/__init__.py -------------------------------------------------------------------------------- /core/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def common_member(a, b): 5 | a_set = set(a) 6 | b_set = set(b) 7 | if (a_set & b_set): 8 | return True 9 | else: 10 | return False 11 | 12 | def arcradians(cos, sin): 13 | if cos > 0 and sin > 0: 14 | return np.arccos(cos) 15 | elif cos > 0 and sin < 0: 16 | return np.arcsin(sin) 17 | elif cos < 0 and sin > 0: 18 | return np.arccos(cos) 19 | elif cos < 0 and sin < 0: 20 | return -1 * np.arccos(cos) 21 | 22 | 23 | def normalize_rows(x): 24 | return x / np.linalg.norm(x, ord=2, axis=1, keepdims=True) 25 | 26 | def copy_row(x, num_rows): 27 | return np.multiply(np.ones((num_rows, 1)), x) 28 | 29 | def expectile_loss(diff, expectile=0.8): 30 | weight = torch.where(diff > 0, expectile, (1 - expectile)) 31 | return weight * (diff ** 2) 32 | 33 | def search_same_row(matrix, target_row): 34 | idx = np.where(np.all(matrix == target_row, axis=1)) 35 | return idx -------------------------------------------------------------------------------- /core/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import logging 5 | 6 | # from tensorboardX import SummaryWriter 7 | 8 | def log_config(cfg): 9 | def get_print_attrs(cfg): 10 | attrs = dict(cfg.__dict__) 11 | for k in ['logger', 'env_fn', 'offline_data']: 12 | del attrs[k] 13 | return attrs 14 | attrs = get_print_attrs(cfg) 15 | for param, value in attrs.items(): 16 | cfg.logger.info('{}: {}'.format(param, value)) 17 | 18 | 19 | class Logger: 20 | def __init__(self, config, log_dir): 21 | log_file = os.path.join(log_dir, 'log') 22 | self._logger = logging.getLogger() 23 | 24 | file_handler = logging.FileHandler(log_file, mode='w') 25 | formatter = logging.Formatter('%(asctime)s | %(message)s') 26 | file_handler.setFormatter(formatter) 27 | self._logger.addHandler(file_handler) 28 | 29 | stream_handler = logging.StreamHandler(sys.stdout) 30 | stream_handler.setFormatter(formatter) 31 | self._logger.addHandler(stream_handler) 32 | 33 | self._logger.setLevel(level=logging.INFO) 34 | 35 | self.config = config 36 | # if config.tensorboard_logs: self.tensorboard_writer = SummaryWriter(config.get_log_dir()) 37 | 38 | def info(self, log_msg): 39 | self._logger.info(log_msg) -------------------------------------------------------------------------------- /core/utils/run_funcs.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import time 3 | import copy 4 | import numpy as np 5 | 6 | import os 7 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1' 8 | import gym 9 | import d4rl 10 | import gzip 11 | 12 | 
EARLYCUTOFF = "EarlyCutOff" 13 | 14 | 15 | def load_testset(env_name, dataset, id): 16 | path = None 17 | if env_name == 'HalfCheetah': 18 | if dataset == 'expert': 19 | path = {"env": "halfcheetah-expert-v2"} 20 | elif dataset == 'medexp': 21 | path = {"env": "halfcheetah-medium-expert-v2"} 22 | elif dataset == 'medium': 23 | path = {"env": "halfcheetah-medium-v2"} 24 | elif dataset == 'medrep': 25 | path = {"env": "halfcheetah-medium-replay-v2"} 26 | elif env_name == 'Walker2d': 27 | if dataset == 'expert': 28 | path = {"env": "walker2d-expert-v2"} 29 | elif dataset == 'medexp': 30 | path = {"env": "walker2d-medium-expert-v2"} 31 | elif dataset == 'medium': 32 | path = {"env": "walker2d-medium-v2"} 33 | elif dataset == 'medrep': 34 | path = {"env": "walker2d-medium-replay-v2"} 35 | elif env_name == 'Hopper': 36 | if dataset == 'expert': 37 | path = {"env": "hopper-expert-v2"} 38 | elif dataset == 'medexp': 39 | path = {"env": "hopper-medium-expert-v2"} 40 | elif dataset == 'medium': 41 | path = {"env": "hopper-medium-v2"} 42 | elif dataset == 'medrep': 43 | path = {"env": "hopper-medium-replay-v2"} 44 | elif env_name == 'Ant': 45 | if dataset == 'expert': 46 | path = {"env": "ant-expert-v2"} 47 | elif dataset == 'medexp': 48 | path = {"env": "ant-medium-expert-v2"} 49 | elif dataset == 'medium': 50 | path = {"env": "ant-medium-v2"} 51 | elif dataset == 'medrep': 52 | path = {"env": "ant-medium-replay-v2"} 53 | 54 | elif env_name == 'Acrobot': 55 | if dataset == 'expert': 56 | path = {"pkl": "data/dataset/acrobot/transitions_50k/train_40k/{}_run.pkl".format(id)} 57 | elif dataset == 'mixed': 58 | path = {"pkl": "data/dataset/acrobot/transitions_50k/train_mixed/{}_run.pkl".format(id)} 59 | elif env_name == 'LunarLander': 60 | if dataset == 'expert': 61 | path = {"pkl": "data/dataset/lunar_lander/transitions_50k/train_500k/{}_run.pkl".format(id)} 62 | elif dataset == 'mixed': 63 | path = {"pkl": "data/dataset/lunar_lander/transitions_50k/train_mixed/{}_run.pkl".format(id)} 64 | elif env_name == 'MountainCar': 65 | if dataset == 'expert': 66 | path = {"pkl": "data/dataset/mountain_car/transitions_50k/train_60k/{}_run.pkl".format(id)} 67 | elif dataset == 'mixed': 68 | path = {"pkl": "data/dataset/mountain_car/transitions_50k/train_mixed/{}_run.pkl".format(id)} 69 | 70 | assert path is not None 71 | testsets = {} 72 | for name in path: 73 | if name == "env": 74 | env = gym.make(path['env']) 75 | try: 76 | data = env.get_dataset() 77 | except: 78 | env = env.unwrapped 79 | data = env.get_dataset() 80 | testsets[name] = { 81 | 'states': data['observations'], 82 | 'actions': data['actions'], 83 | 'rewards': data['rewards'], 84 | 'next_states': data['next_observations'], 85 | 'terminations': data['terminals'], 86 | } 87 | else: 88 | pth = path[name] 89 | with open(pth.format(id), 'rb') as f: 90 | testsets[name] = pickle.load(f) 91 | 92 | return testsets 93 | else: 94 | return {} 95 | 96 | def run_steps(agent, max_steps, log_interval, eval_pth): 97 | t0 = time.time() 98 | evaluations = [] 99 | agent.populate_returns(initialize=True) 100 | while True: 101 | if log_interval and not agent.total_steps % log_interval: 102 | mean, median, min_, max_ = agent.log_file(elapsed_time=log_interval / (time.time() - t0), test=True) 103 | evaluations.append(mean) 104 | t0 = time.time() 105 | if max_steps and agent.total_steps >= max_steps: 106 | break 107 | agent.step() 108 | agent.save() 109 | np.save(eval_pth+"/evaluations.npy", np.array(evaluations)) 
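# Descriptive note on the helpers above: load_testset() maps (env_name, dataset) to the
# matching D4RL dataset (or to a local pickle for the discrete-control tasks) and returns a
# single-entry dict whose value holds the 'states', 'actions', 'rewards', 'next_states', and
# 'terminations' arrays consumed by Agent.training_set_construction(). run_steps() evaluates
# the agent every log_interval steps via agent.log_file() and appends the logged mean return
# (the D4RL-normalized score when the environment provides get_normalized_score, otherwise
# the raw test return) to <eval_pth>/evaluations.npy.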
-------------------------------------------------------------------------------- /core/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def tensor(x, device): 7 | if isinstance(x, torch.Tensor): 8 | return x 9 | x = torch.tensor(x, dtype=torch.float32).to(device) 10 | return x 11 | 12 | def to_np(t): 13 | return t.cpu().detach().numpy() 14 | 15 | def random_seed(seed): 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | 19 | def set_one_thread(): 20 | os.environ['OMP_NUM_THREADS'] = '1' 21 | os.environ['MKL_NUM_THREADS'] = '1' 22 | torch.set_num_threads(1) 23 | 24 | def ensure_dir(d): 25 | if not os.path.exists(d): 26 | os.makedirs(d) 27 | -------------------------------------------------------------------------------- /img/after_fix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/img/after_fix.png -------------------------------------------------------------------------------- /run_ac_offline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import core.environment.env_factory as environment 5 | from core.utils import torch_utils, logger, run_funcs 6 | from core.agent.in_sample import * 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description="run_file") 11 | parser.add_argument('--seed', default=0, type=int) 12 | parser.add_argument('--env_name', default='Ant', type=str) 13 | parser.add_argument('--dataset', default='medexp', type=str) 14 | parser.add_argument('--discrete_control', default=0, type=int) 15 | parser.add_argument('--state_dim', default=1, type=int) 16 | parser.add_argument('--action_dim', default=1, type=int) 17 | parser.add_argument('--tau', default=0.1, type=float) 18 | 19 | parser.add_argument('--max_steps', default=1000000, type=int) 20 | parser.add_argument('--log_interval', default=10000, type=int) 21 | parser.add_argument('--learning_rate', default=3e-4, type=float) 22 | parser.add_argument('--hidden_units', default=256, type=int) 23 | parser.add_argument('--batch_size', default=256, type=int) 24 | parser.add_argument('--timeout', default=1000, type=int) 25 | parser.add_argument('--gamma', default=0.99, type=float) 26 | parser.add_argument('--use_target_network', default=1, type=int) 27 | parser.add_argument('--target_network_update_freq', default=1, type=int) 28 | parser.add_argument('--polyak', default=0.995, type=float) 29 | parser.add_argument('--evaluation_criteria', default='return', type=str) 30 | parser.add_argument('--device', default='cpu', type=str) 31 | parser.add_argument('--info', default='0', type=str) 32 | cfg = parser.parse_args() 33 | 34 | torch_utils.set_one_thread() 35 | 36 | torch_utils.random_seed(cfg.seed) 37 | 38 | project_root = os.path.abspath(os.path.dirname(__file__)) 39 | exp_path = "data/output/{}/{}/{}/{}_run".format(cfg.env_name, cfg.dataset, cfg.info, cfg.seed) 40 | cfg.exp_path = os.path.join(project_root, exp_path) 41 | torch_utils.ensure_dir(cfg.exp_path) 42 | cfg.env_fn = environment.EnvFactory.create_env_fn(cfg) 43 | cfg.offline_data = run_funcs.load_testset(cfg.env_name, cfg.dataset, cfg.seed) 44 | 45 | # Setting up the logger 46 | cfg.logger = logger.Logger(cfg, cfg.exp_path) 47 | logger.log_config(cfg) 48 | 49 | # Initializing the agent and running the 
experiment 50 | agent_obj = InSampleAC( 51 | device=cfg.device, 52 | discrete_control=cfg.discrete_control, 53 | state_dim=cfg.state_dim, 54 | action_dim=cfg.action_dim, 55 | hidden_units=cfg.hidden_units, 56 | learning_rate=cfg.learning_rate, 57 | tau=cfg.tau, 58 | polyak=cfg.polyak, 59 | exp_path=cfg.exp_path, 60 | seed=cfg.seed, 61 | env_fn=cfg.env_fn, 62 | timeout=cfg.timeout, 63 | gamma=cfg.gamma, 64 | offline_data=cfg.offline_data, 65 | batch_size=cfg.batch_size, 66 | use_target_network=cfg.use_target_network, 67 | target_network_update_freq=cfg.target_network_update_freq, 68 | evaluation_criteria=cfg.evaluation_criteria, 69 | logger=cfg.logger 70 | ) 71 | run_funcs.run_steps(agent_obj, cfg.max_steps, cfg.log_interval, exp_path) --------------------------------------------------------------------------------
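Usage note (a minimal sketch, not a file in the repository): `run_funcs.run_steps` writes the evaluation curve to `evaluations.npy` inside the experiment directory that `run_ac_offline.py` creates under `data/output/<env_name>/<dataset>/<info>/<seed>_run`. The snippet below shows one way to inspect that file after training; the path is only an example, assuming the README's Ant/medexp command with the default `--info 0` and `--seed 0`.

```
import numpy as np

# Example experiment directory produced by:
#   python run_ac_offline.py --env_name Ant --dataset medexp --seed 0 ...
exp_path = "data/output/Ant/medexp/0/0_run"

# run_steps() appends one logged mean return per --log_interval steps.
evaluations = np.load(exp_path + "/evaluations.npy")
print("evaluation points:", len(evaluations))
print("final logged mean return:", evaluations[-1])
```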