├── .gitignore
├── README.md
├── core
│   ├── __init__.py
│   ├── agent
│   │   ├── base.py
│   │   └── in_sample.py
│   ├── environment
│   │   ├── __init__.py
│   │   ├── acrobot.py
│   │   ├── ant.py
│   │   ├── env_factory.py
│   │   ├── halfcheetah.py
│   │   ├── hopper.py
│   │   ├── lunarlander.py
│   │   ├── mountaincar.py
│   │   └── walker2d.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── network_architectures.py
│   │   ├── network_bodies.py
│   │   ├── network_utils.py
│   │   └── policy_factory.py
│   └── utils
│       ├── __init__.py
│       ├── helpers.py
│       ├── logger.py
│       ├── run_funcs.py
│       └── torch_utils.py
├── img
│   └── after_fix.png
└── run_ac_offline.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # Pycharm
121 | .idea
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # Data
135 | data/
136 | output/
137 | plot/img
138 |
139 | # Cache
140 | *__pycache__*
141 | *.pyc
142 |
143 | #CMD
144 | cmd*.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | This is a code release for our paper 'The In-Sample Softmax for Offline Reinforcement Learning' (https://openreview.net/pdf?id=u-RuvyDYqCM).
2 |
3 | # Running the code:
4 |
5 | ```
6 | python run_ac_offline.py --seed 0 --env_name Ant --dataset expert --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
7 |
8 | python run_ac_offline.py --seed 0 --env_name Ant --dataset medexp --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
9 |
10 | python run_ac_offline.py --seed 0 --env_name Ant --dataset medium --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
11 |
12 | python run_ac_offline.py --seed 0 --env_name Ant --dataset medrep --discrete_control 0 --state_dim 111 --action_dim 8 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
13 |
14 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset expert --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
15 |
16 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset medexp --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.1 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
17 |
18 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset medium --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.33 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
19 |
20 | python run_ac_offline.py --seed 0 --env_name HalfCheetah --dataset medrep --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
21 |
22 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset expert --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
23 |
24 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset medexp --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
25 |
26 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset medium --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.1 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
27 |
28 | python run_ac_offline.py --seed 0 --env_name Hopper --dataset medrep --discrete_control 0 --state_dim 11 --action_dim 3 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
29 |
30 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset expert --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.01 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
31 |
32 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset medexp --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.1 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
33 |
34 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset medium --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.33 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
35 |
36 | python run_ac_offline.py --seed 0 --env_name Walker2d --dataset medrep --discrete_control 0 --state_dim 17 --action_dim 6 --tau 0.5 --learning_rate 0.0003 --hidden_units 256 --batch_size 256 --timeout 1000 --max_steps 1000000 --log_interval 10000
37 | ```
38 |
39 | **Update:**
40 |
41 | We fixed the policy network for continuous control (thanks to @typoverflow!). We reran the affected baselines with 5 runs each. The hyperparameters have been updated above, and the results are reported below.
42 | The fix **did not** change the **overall performance** or the **conclusions** reported in the paper.
43 |
44 |
45 |
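# Checking the output

Each command writes its evaluation curve to `data/output/<env_name>/<dataset>/<info>/<seed>_run/evaluations.npy` (one mean evaluation score per `--log_interval` steps; the normalized D4RL score is used when the environment provides one), and the training log goes to a `log` file in the same folder. A minimal sketch for reading one curve, using the path produced by the first Ant command with the default `--info 0`:
```
import numpy as np

# Layout created by run_ac_offline.py: data/output/<env_name>/<dataset>/<info>/<seed>_run/
curve = np.load("data/output/Ant/expert/0/0_run/evaluations.npy")
print(len(curve), curve[-1])  # number of logged points and the latest mean evaluation score
```
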
46 | # D4RL installation
47 | If you are using *Ubuntu* and do not have *d4rl* installed yet, this section may help.
48 |
49 | 1. Download MuJoCo
50 |
51 | I am using mujoco210, which can be downloaded from https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz
52 | ```
53 | mkdir ~/.mujoco
54 | mv mujoco210-linux-x86_64.tar.gz ~/.mujoco
55 | cd ~/.mujoco
56 | tar -xvzf mujoco210-linux-x86_64.tar.gz
57 | ```
58 |
59 | Then, add the MuJoCo library path:
60 |
61 | Open the ~/.bashrc file and add the following line:
62 | ```
63 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin
64 | ```
65 |
66 | Save the change and run the following command:
67 | ```
68 | source ~/.bashrc
69 | ```
70 |
71 | 2. Install other packages and D4RL
72 | ```
73 | pip install mujoco_py
74 | pip install dm_control==1.0.7
75 | pip install git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
76 | ```
77 |
78 | 3. Test the installation in Python
79 | ```
80 | import gym
81 | import d4rl
82 | env = gym.make('maze2d-umaze-v1')
83 | env.get_dataset()
84 | ```
85 | ```
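
4. (Optional) Check how this repo consumes the dataset

`core/utils/run_funcs.py` (`load_testset`) converts `env.get_dataset()` into a plain dict of arrays before filling the replay buffer. A minimal sketch of that conversion, assuming the installation above:
```
import gym
import d4rl

env = gym.make('halfcheetah-expert-v2')
data = env.get_dataset()
# Same field mapping as load_testset in core/utils/run_funcs.py
offline_data = {'env': {
    'states': data['observations'],
    'actions': data['actions'],
    'rewards': data['rewards'],
    'next_states': data['next_observations'],
    'terminations': data['terminals'],
}}
print({k: v.shape for k, v in offline_data['env'].items()})
```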
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/__init__.py
--------------------------------------------------------------------------------
/core/agent/base.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pickle
5 | import torch
6 | import copy
7 |
8 | from core.utils import torch_utils
9 |
10 |
11 | class Replay:
12 | def __init__(self, memory_size, batch_size, seed=0):
13 | self.rng = np.random.RandomState(seed)
14 | self.memory_size = memory_size
15 | self.batch_size = batch_size
16 | self.data = []
17 | self.pos = 0
18 |
19 | def feed(self, experience):
20 | if self.pos >= len(self.data):
21 | self.data.append(experience)
22 | else:
23 | self.data[self.pos] = experience
24 | self.pos = (self.pos + 1) % self.memory_size
25 |
26 | def feed_batch(self, experience):
27 | for exp in experience:
28 | self.feed(exp)
29 |
30 | def sample(self, batch_size=None):
31 | if batch_size is None:
32 | batch_size = self.batch_size
33 | sampled_indices = [self.rng.randint(0, len(self.data)) for _ in range(batch_size)]
34 | sampled_data = [self.data[ind] for ind in sampled_indices]
35 | batch_data = list(map(lambda x: np.asarray(x), zip(*sampled_data)))
36 |
37 | return batch_data
38 |
39 | def sample_array(self, batch_size=None):
40 | if batch_size is None:
41 | batch_size = self.batch_size
42 |
43 | sampled_indices = [self.rng.randint(0, len(self.data)) for _ in range(batch_size)]
44 | sampled_data = [self.data[ind] for ind in sampled_indices]
45 |
46 | return sampled_data
47 |
48 | def size(self):
49 | return len(self.data)
50 |
51 | def persist_memory(self, dir):
52 | for k in range(len(self.data)):
53 | transition = self.data[k]
54 | with open(os.path.join(dir, str(k)), "wb") as f:
55 | pickle.dump(transition, f)
56 |
57 | def clear(self):
58 | self.data = []
59 | self.pos = 0
60 |
61 | def get_buffer(self):
62 | return self.data
63 |
64 |
65 | class Agent:
66 | def __init__(self,
67 | exp_path,
68 | seed,
69 | env_fn,
70 | timeout,
71 | gamma,
72 | offline_data,
73 | action_dim,
74 | batch_size,
75 | use_target_network,
76 | target_network_update_freq,
77 | evaluation_criteria,
78 | logger
79 | ):
80 | self.exp_path = exp_path
81 | self.seed = seed
82 | self.use_target_network = use_target_network
83 | self.target_network_update_freq = target_network_update_freq
84 | self.parameters_dir = self.get_parameters_dir()
85 |
86 | self.batch_size = batch_size
87 | self.env = env_fn()
88 | self.eval_env = copy.deepcopy(env_fn)()
89 | self.offline_data = offline_data
90 | self.replay = Replay(memory_size=2000000, batch_size=batch_size, seed=seed)
91 | self.state_normalizer = lambda x: x
92 | self.evaluation_criteria = evaluation_criteria
93 | self.logger = logger
94 | self.timeout = timeout
95 | self.action_dim = action_dim
96 |
97 | self.gamma = gamma
98 | self.device = 'cpu'
99 | self.stats_queue_size = 5
100 | self.episode_reward = 0
101 | self.episode_rewards = []
102 | self.total_steps = 0
103 | self.reset = True
104 | self.ep_steps = 0
105 | self.num_episodes = 0
106 | self.ep_returns_queue_train = np.zeros(self.stats_queue_size)
107 | self.ep_returns_queue_test = np.zeros(self.stats_queue_size)
108 | self.train_stats_counter = 0
109 | self.test_stats_counter = 0
110 | self.agent_rng = np.random.RandomState(self.seed)
111 |
112 | self.populate_latest = False
113 | self.populate_states, self.populate_actions, self.populate_true_qs = None, None, None
114 | self.automatic_tmp_tuning = False
115 |
116 | self.state = None
117 | self.action = None
118 | self.next_state = None
119 | self.eps = 1e-8
120 |
121 | def get_parameters_dir(self):
122 | d = os.path.join(self.exp_path, "parameters")
123 | torch_utils.ensure_dir(d)
124 | return d
125 |
126 | def offline_param_init(self):
127 | self.trainset = self.training_set_construction(self.offline_data)
128 | self.training_size = len(self.trainset[0])
129 | self.training_indexs = np.arange(self.training_size)
130 |
131 | self.training_loss = []
132 | self.test_loss = []
133 | self.tloss_increase = 0
134 | self.tloss_rec = np.inf
135 |
136 | def get_data(self):
137 | states, actions, rewards, next_states, terminals = self.replay.sample()
138 | in_ = torch_utils.tensor(self.state_normalizer(states), self.device)
139 | r = torch_utils.tensor(rewards, self.device)
140 | ns = torch_utils.tensor(self.state_normalizer(next_states), self.device)
141 | t = torch_utils.tensor(terminals, self.device)
142 | data = {
143 | 'obs': in_,
144 | 'act': actions,
145 | 'reward': r,
146 | 'obs2': ns,
147 | 'done': t
148 | }
149 | return data
150 |
151 | def fill_offline_data_to_buffer(self):
152 | self.trainset = self.training_set_construction(self.offline_data)
153 | train_s, train_a, train_r, train_ns, train_t = self.trainset
154 | for idx in range(len(train_s)):
155 | self.replay.feed([train_s[idx], train_a[idx], train_r[idx], train_ns[idx], train_t[idx]])
156 |
157 | def step(self):
158 | # trans = self.feed_data()
159 | self.update_stats(0, None)
160 | data = self.get_data()
161 | losses = self.update(data)
162 | return losses
163 |
164 | def update(self, data):
165 | raise NotImplementedError
166 |
167 | def update_stats(self, reward, done):
168 | self.episode_reward += reward
169 | self.total_steps += 1
170 | self.ep_steps += 1
171 | if done or self.ep_steps == self.timeout:
172 | self.episode_rewards.append(self.episode_reward)
173 | self.num_episodes += 1
174 | if self.evaluation_criteria == "return":
175 | self.add_train_log(self.episode_reward)
176 | elif self.evaluation_criteria == "steps":
177 | self.add_train_log(self.ep_steps)
178 | else:
179 | raise NotImplementedError
180 | self.episode_reward = 0
181 | self.ep_steps = 0
182 | self.reset = True
183 |
184 | def add_train_log(self, ep_return):
185 | self.ep_returns_queue_train[self.train_stats_counter] = ep_return
186 | self.train_stats_counter += 1
187 | self.train_stats_counter = self.train_stats_counter % self.stats_queue_size
188 |
189 | def add_test_log(self, ep_return):
190 | self.ep_returns_queue_test[self.test_stats_counter] = ep_return
191 | self.test_stats_counter += 1
192 | self.test_stats_counter = self.test_stats_counter % self.stats_queue_size
193 |
194 | def populate_returns(self, log_traj=False, total_ep=None, initialize=False):
195 | total_ep = self.stats_queue_size if total_ep is None else total_ep
196 | total_steps = 0
197 | total_states = []
198 | total_actions = []
199 | total_returns = []
200 | for ep in range(total_ep):
201 | ep_return, steps, traj = self.eval_episode(log_traj=log_traj)
202 | total_steps += steps
203 | total_states += traj[0]
204 | total_actions += traj[1]
205 | total_returns += traj[2]
206 | if self.evaluation_criteria == "return":
207 | self.add_test_log(ep_return)
208 | if initialize:
209 | self.add_train_log(ep_return)
210 | elif self.evaluation_criteria == "steps":
211 | self.add_test_log(steps)
212 | if initialize:
213 | self.add_train_log(steps)
214 | else:
215 | raise NotImplementedError
216 | return [total_states, total_actions, total_returns]
217 |
218 | def eval_episode(self, log_traj=False):
219 | ep_traj = []
220 | state = self.eval_env.reset()
221 | total_rewards = 0
222 | ep_steps = 0
223 | done = False
224 | while True:
225 | action = self.eval_step(state)
226 | last_state = state
227 | state, reward, done, _ = self.eval_env.step([action])
228 | # print(np.abs(state-last_state).sum(), "\n",action)
229 | if log_traj:
230 | ep_traj.append([last_state, action, reward])
231 | total_rewards += reward
232 | ep_steps += 1
233 | if done or ep_steps == self.timeout:
234 | break
235 |
236 | states = []
237 | actions = []
238 | rets = []
239 | if log_traj:
240 | ret = 0
241 | for i in range(len(ep_traj)-1, -1, -1):
242 | s, a, r = ep_traj[i]
243 | ret = r + self.gamma * ret
244 | rets.insert(0, ret)
245 | actions.insert(0, a)
246 | states.insert(0, s)
247 | return total_rewards, ep_steps, [states, actions, rets]
248 |
249 | def log_return(self, log_ary, name, elapsed_time):
250 | rewards = log_ary
251 | total_episodes = len(self.episode_rewards)
252 | mean, median, min_, max_ = np.mean(rewards), np.median(rewards), np.min(rewards), np.max(rewards)
253 |
254 | log_str = '%s LOG: steps %d, episodes %3d, ' \
255 | 'returns %.2f/%.2f/%.2f/%.2f/%d (mean/median/min/max/num), %.2f steps/s'
256 |
257 | self.logger.info(log_str % (name, self.total_steps, total_episodes, mean, median,
258 | min_, max_, len(rewards),
259 | elapsed_time))
260 | return mean, median, min_, max_
261 |
262 | def log_file(self, elapsed_time=-1, test=True):
263 | mean, median, min_, max_ = self.log_return(self.ep_returns_queue_train, "TRAIN", elapsed_time)
264 | if test:
265 | self.populate_states, self.populate_actions, self.populate_true_qs = self.populate_returns(log_traj=True)
266 | self.populate_latest = True
267 | mean, median, min_, max_ = self.log_return(self.ep_returns_queue_test, "TEST", elapsed_time)
268 | try:
269 | normalized = np.array([self.eval_env.env.unwrapped.get_normalized_score(ret_) for ret_ in self.ep_returns_queue_test])
270 | mean, median, min_, max_ = self.log_return(normalized, "Normalized", elapsed_time)
271 |             except Exception:
272 |                 pass  # not a d4rl env; no normalized score to report
273 | return mean, median, min_, max_
274 |
275 | def policy(self, o, eval=False):
276 | o = torch_utils.tensor(self.state_normalizer(o), self.device)
277 | with torch.no_grad():
278 | a, _ = self.ac.pi(o, deterministic=eval)
279 | a = torch_utils.to_np(a)
280 | return a
281 |
282 | def eval_step(self, state):
283 | a = self.policy(state, eval=True)
284 | return a
285 |
286 | def training_set_construction(self, data_dict):
287 | assert len(list(data_dict.keys())) == 1
288 | data_dict = data_dict[list(data_dict.keys())[0]]
289 | states = data_dict['states']
290 | actions = data_dict['actions']
291 | rewards = data_dict['rewards']
292 | next_states = data_dict['next_states']
293 | terminations = data_dict['terminations']
294 | return [states, actions, rewards, next_states, terminations]
295 |
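
A minimal usage sketch of the `Replay` buffer defined above, with toy shapes and values (the offline agents feed `[state, action, reward, next_state, termination]` tuples through `fill_offline_data_to_buffer` and draw stacked numpy batches through `sample`):
```
import numpy as np
from core.agent.base import Replay

buffer = Replay(memory_size=2000000, batch_size=4, seed=0)
for i in range(10):
    s, ns = np.full(3, i, dtype=np.float32), np.full(3, i + 1, dtype=np.float32)
    a, r, t = np.zeros(2, dtype=np.float32), float(i), 0.0
    buffer.feed([s, a, r, ns, t])

states, actions, rewards, next_states, terminals = buffer.sample()
print(states.shape, actions.shape, rewards.shape)  # (4, 3) (4, 2) (4,)
```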
--------------------------------------------------------------------------------
/core/agent/in_sample.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from core.agent import base
3 | from collections import namedtuple
4 | import os
5 | import torch
6 |
7 | from core.network.policy_factory import MLPCont, MLPDiscrete
8 | from core.network.network_architectures import DoubleCriticNetwork, DoubleCriticDiscrete, FCNetwork
9 |
10 | class InSampleAC(base.Agent):
11 | def __init__(self,
12 | device,
13 | discrete_control,
14 | state_dim,
15 | action_dim,
16 | hidden_units,
17 | learning_rate,
18 | tau,
19 | polyak,
20 | exp_path,
21 | seed,
22 | env_fn,
23 | timeout,
24 | gamma,
25 | offline_data,
26 | batch_size,
27 | use_target_network,
28 | target_network_update_freq,
29 | evaluation_criteria,
30 | logger
31 | ):
32 | super(InSampleAC, self).__init__(
33 | exp_path=exp_path,
34 | seed=seed,
35 | env_fn=env_fn,
36 | timeout=timeout,
37 | gamma=gamma,
38 | offline_data=offline_data,
39 | action_dim=action_dim,
40 | batch_size=batch_size,
41 | use_target_network=use_target_network,
42 | target_network_update_freq=target_network_update_freq,
43 | evaluation_criteria=evaluation_criteria,
44 | logger=logger
45 | )
46 |
47 | def get_policy_func():
48 | if discrete_control:
49 | pi = MLPDiscrete(device, state_dim, action_dim, [hidden_units]*2)
50 | else:
51 | pi = MLPCont(device, state_dim, action_dim, [hidden_units]*2)
52 | return pi
53 |
54 | def get_critic_func():
55 | if discrete_control:
56 | q1q2 = DoubleCriticDiscrete(device, state_dim, [hidden_units]*2, action_dim)
57 | else:
58 | q1q2 = DoubleCriticNetwork(device, state_dim, action_dim, [hidden_units]*2)
59 | return q1q2
60 |
61 | pi = get_policy_func()
62 | q1q2 = get_critic_func()
63 | AC = namedtuple('AC', ['q1q2', 'pi'])
64 | self.ac = AC(q1q2=q1q2, pi=pi)
65 | pi_target = get_policy_func()
66 | q1q2_target = get_critic_func()
67 | q1q2_target.load_state_dict(q1q2.state_dict())
68 | pi_target.load_state_dict(pi.state_dict())
69 | ACTarg = namedtuple('ACTarg', ['q1q2', 'pi'])
70 | self.ac_targ = ACTarg(q1q2=q1q2_target, pi=pi_target)
71 | self.ac_targ.q1q2.load_state_dict(self.ac.q1q2.state_dict())
72 | self.ac_targ.pi.load_state_dict(self.ac.pi.state_dict())
73 | self.beh_pi = get_policy_func()
74 | self.value_net = FCNetwork(device, np.prod(state_dim), [hidden_units]*2, 1)
75 |
76 | self.pi_optimizer = torch.optim.Adam(list(self.ac.pi.parameters()), learning_rate)
77 | self.q_optimizer = torch.optim.Adam(list(self.ac.q1q2.parameters()), learning_rate)
78 | self.value_optimizer = torch.optim.Adam(list(self.value_net.parameters()), learning_rate)
79 | self.beh_pi_optimizer = torch.optim.Adam(list(self.beh_pi.parameters()), learning_rate)
80 | self.exp_threshold = 10000
81 | if discrete_control:
82 | self.get_q_value = self.get_q_value_discrete
83 | self.get_q_value_target = self.get_q_value_target_discrete
84 | else:
85 | self.get_q_value = self.get_q_value_cont
86 | self.get_q_value_target = self.get_q_value_target_cont
87 |
88 | self.tau = tau
89 | self.polyak = polyak
90 | self.fill_offline_data_to_buffer()
91 | self.offline_param_init()
92 | return
93 |
94 |
95 | def compute_loss_beh_pi(self, data):
96 |         r"""L_{\omega}, learn behavior policy"""
97 | states, actions = data['obs'], data['act']
98 | beh_log_probs = self.beh_pi.get_logprob(states, actions)
99 | beh_loss = -beh_log_probs.mean()
100 | return beh_loss, beh_log_probs
101 |
102 | def compute_loss_value(self, data):
103 |         r"""L_{\phi}, learn z for state value, v = tau log z"""
104 | states = data['obs']
105 | v_phi = self.value_net(states).squeeze(-1)
106 | with torch.no_grad():
107 | actions, log_probs = self.ac.pi(states)
108 | min_Q, _, _ = self.get_q_value_target(states, actions)
109 | target = min_Q - self.tau * log_probs
110 | value_loss = (0.5 * (v_phi - target) ** 2).mean()
111 | return value_loss, v_phi.detach().numpy(), log_probs.detach().numpy()
112 |
113 | def get_state_value(self, state):
114 | with torch.no_grad():
115 | value = self.value_net(state).squeeze(-1)
116 | return value
117 |
118 | def compute_loss_q(self, data):
119 | states, actions, rewards, next_states, dones = data['obs'], data['act'], data['reward'], data['obs2'], data['done']
120 | with torch.no_grad():
121 | next_actions, log_probs = self.ac.pi(next_states)
122 | min_Q, _, _ = self.get_q_value_target(next_states, next_actions)
123 | q_target = rewards + self.gamma * (1 - dones) * (min_Q - self.tau * log_probs)
124 |
125 | minq, q1, q2 = self.get_q_value(states, actions, with_grad=True)
126 |
127 | critic1_loss = (0.5 * (q_target - q1) ** 2).mean()
128 | critic2_loss = (0.5 * (q_target - q2) ** 2).mean()
129 | loss_q = (critic1_loss + critic2_loss) * 0.5
130 | q_info = minq.detach().numpy()
131 | return loss_q, q_info
132 |
133 | def compute_loss_pi(self, data):
134 |         r"""L_{\psi}, extract learned policy"""
135 | states, actions = data['obs'], data['act']
136 |
137 | log_probs = self.ac.pi.get_logprob(states, actions)
138 | min_Q, _, _ = self.get_q_value(states, actions, with_grad=False)
139 | with torch.no_grad():
140 | value = self.get_state_value(states)
141 | beh_log_prob = self.beh_pi.get_logprob(states, actions)
142 |
143 | clipped = torch.clip(torch.exp((min_Q - value) / self.tau - beh_log_prob), self.eps, self.exp_threshold)
144 | pi_loss = -(clipped * log_probs).mean()
145 | return pi_loss, ""
146 |
147 | def update_beta(self, data):
148 | loss_beh_pi, _ = self.compute_loss_beh_pi(data)
149 | self.beh_pi_optimizer.zero_grad()
150 | loss_beh_pi.backward()
151 | self.beh_pi_optimizer.step()
152 | return loss_beh_pi
153 |
154 | def update(self, data):
155 | loss_beta = self.update_beta(data).item()
156 |
157 | self.value_optimizer.zero_grad()
158 | loss_vs, v_info, logp_info = self.compute_loss_value(data)
159 | loss_vs.backward()
160 | self.value_optimizer.step()
161 |
162 | loss_q, qinfo = self.compute_loss_q(data)
163 | self.q_optimizer.zero_grad()
164 | loss_q.backward()
165 | self.q_optimizer.step()
166 |
167 | loss_pi, _ = self.compute_loss_pi(data)
168 | self.pi_optimizer.zero_grad()
169 | loss_pi.backward()
170 | self.pi_optimizer.step()
171 |
172 | if self.use_target_network and self.total_steps % self.target_network_update_freq == 0:
173 | self.sync_target()
174 |
175 | return {"beta": loss_beta,
176 | "actor": loss_pi.item(),
177 | "critic": loss_q.item(),
178 | "value": loss_vs.item(),
179 | "q_info": qinfo.mean(),
180 | "v_info": v_info.mean(),
181 | "logp_info": logp_info.mean(),
182 | }
183 |
184 |
185 | def get_q_value_discrete(self, o, a, with_grad=False):
186 | if with_grad:
187 | q1_pi, q2_pi = self.ac.q1q2(o)
188 | q1_pi, q2_pi = q1_pi[np.arange(len(a)), a], q2_pi[np.arange(len(a)), a]
189 | q_pi = torch.min(q1_pi, q2_pi)
190 | else:
191 | with torch.no_grad():
192 | q1_pi, q2_pi = self.ac.q1q2(o)
193 | q1_pi, q2_pi = q1_pi[np.arange(len(a)), a], q2_pi[np.arange(len(a)), a]
194 | q_pi = torch.min(q1_pi, q2_pi)
195 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1)
196 |
197 | def get_q_value_target_discrete(self, o, a):
198 | with torch.no_grad():
199 | q1_pi, q2_pi = self.ac_targ.q1q2(o)
200 | q1_pi, q2_pi = q1_pi[np.arange(len(a)), a], q2_pi[np.arange(len(a)), a]
201 | q_pi = torch.min(q1_pi, q2_pi)
202 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1)
203 |
204 | def get_q_value_cont(self, o, a, with_grad=False):
205 | if with_grad:
206 | q1_pi, q2_pi = self.ac.q1q2(o, a)
207 | q_pi = torch.min(q1_pi, q2_pi)
208 | else:
209 | with torch.no_grad():
210 | q1_pi, q2_pi = self.ac.q1q2(o, a)
211 | q_pi = torch.min(q1_pi, q2_pi)
212 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1)
213 |
214 | def get_q_value_target_cont(self, o, a):
215 | with torch.no_grad():
216 | q1_pi, q2_pi = self.ac_targ.q1q2(o, a)
217 | q_pi = torch.min(q1_pi, q2_pi)
218 | return q_pi.squeeze(-1), q1_pi.squeeze(-1), q2_pi.squeeze(-1)
219 |
220 | def sync_target(self):
221 | with torch.no_grad():
222 | for p, p_targ in zip(self.ac.q1q2.parameters(), self.ac_targ.q1q2.parameters()):
223 | p_targ.data.mul_(self.polyak)
224 | p_targ.data.add_((1 - self.polyak) * p.data)
225 | for p, p_targ in zip(self.ac.pi.parameters(), self.ac_targ.pi.parameters()):
226 | p_targ.data.mul_(self.polyak)
227 | p_targ.data.add_((1 - self.polyak) * p.data)
228 |
229 | def save(self):
230 | parameters_dir = self.parameters_dir
231 | path = os.path.join(parameters_dir, "actor_net")
232 | torch.save(self.ac.pi.state_dict(), path)
233 |
234 | path = os.path.join(parameters_dir, "critic_net")
235 | torch.save(self.ac.q1q2.state_dict(), path)
236 |
237 | path = os.path.join(parameters_dir, "vs_net")
238 | torch.save(self.value_net.state_dict(), path)
239 |
240 |
241 |
242 |
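
A toy numeric illustration of the in-sample weighting used in `compute_loss_pi` above: each logged action is weighted by `exp((min_Q - V)/tau - log pi_beta)`, clipped to `[eps, exp_threshold]`, before the weighted log-likelihood is maximized. The numbers below are made up; `tau = 0.33` is one of the README settings, and `eps`/`exp_threshold` match the values in this file and `base.py`:
```
import torch

tau, eps, exp_threshold = 0.33, 1e-8, 10000
min_Q = torch.tensor([1.0, 2.0, 0.5])             # clipped double-Q values Q(s, a)
value = torch.tensor([0.8, 1.5, 1.2])             # state values V(s) from value_net
beh_log_prob = torch.tensor([-1.0, -0.5, -2.0])   # log pi_beta(a|s) from the behavior policy
log_prob = torch.tensor([-0.9, -0.7, -1.5])       # log pi(a|s) from the policy being extracted

clipped = torch.clip(torch.exp((min_Q - value) / tau - beh_log_prob), eps, exp_threshold)
pi_loss = -(clipped * log_prob).mean()
print(clipped, pi_loss)
```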
--------------------------------------------------------------------------------
/core/environment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/environment/__init__.py
--------------------------------------------------------------------------------
/core/environment/acrobot.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import gym
5 | import copy
6 |
7 | import core.utils.helpers
8 | from core.utils.torch_utils import random_seed
9 |
10 |
11 | class Acrobot:
12 | def __init__(self, seed=np.random.randint(int(1e5))):
13 | random_seed(seed)
14 | self.state_dim = (6,)
15 | self.action_dim = 3
16 | self.env = gym.make('Acrobot-v1')
17 | self.env._seed = seed
18 | self.env._max_episode_steps = np.inf # control timeout setting in agent
19 | self.state = None
20 |
21 | def generate_state(self, coords):
22 | return coords
23 |
24 | def reset(self):
25 | self.state = np.asarray(self.env.reset())
26 | return self.state
27 |
28 | def step(self, a):
29 | state, reward, done, info = self.env.step(a[0])
30 | self.state = state
31 | # self.env.render()
32 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
33 |
34 | def get_visualization_segment(self):
35 | raise NotImplementedError
36 |
37 | def get_useful(self, state=None):
38 | if state:
39 | return state
40 | else:
41 | return np.array(self.env.state)
42 |
43 | def info(self, key):
44 | return
45 |
46 |
--------------------------------------------------------------------------------
/core/environment/ant.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
3 |
4 | import gym
5 | import d4rl
6 | import numpy as np
7 |
8 | from core.utils.torch_utils import random_seed
9 |
10 |
11 | class Ant:
12 | def __init__(self, seed=np.random.randint(int(1e5))):
13 | random_seed(seed)
14 | self.state_dim = (111,)
15 | self.action_dim = 8
16 | # self.env = gym.make('Ant-v2')
17 |         self.env = gym.make('ant-random-v2')  # Load the d4rl env for the convenience of getting the normalized score from d4rl
18 | self.env.unwrapped.seed(seed)
19 | self.env._max_episode_steps = np.inf # control timeout setting in agent
20 | self.state = None
21 |
22 | def reset(self):
23 | return self.env.reset()
24 |
25 | def step(self, a):
26 | ret = self.env.step(a[0])
27 | state, reward, done, info = ret
28 | self.state = state
29 | # self.env.render()
30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
31 |
32 | def get_visualization_segment(self):
33 | raise NotImplementedError
34 |
35 | def get_useful(self, state=None):
36 | if state:
37 | return state
38 | else:
39 | return np.array(self.env.state)
40 |
41 | def info(self, key):
42 | return
43 |
--------------------------------------------------------------------------------
/core/environment/env_factory.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from core.environment.mountaincar import MountainCar
4 | from core.environment.acrobot import Acrobot
5 | from core.environment.lunarlander import LunarLander
6 | from core.environment.halfcheetah import HalfCheetah
7 | from core.environment.walker2d import Walker2d
8 | from core.environment.hopper import Hopper
9 | from core.environment.ant import Ant
10 |
11 | class EnvFactory:
12 | @classmethod
13 | def create_env_fn(cls, cfg):
14 | if cfg.env_name == 'MountainCar':
15 | return lambda: MountainCar(cfg.seed)
16 | elif cfg.env_name == 'Acrobot':
17 | return lambda: Acrobot(cfg.seed)
18 | elif cfg.env_name == 'LunarLander':
19 | return lambda: LunarLander(cfg.seed)
20 | elif cfg.env_name == 'HalfCheetah':
21 | return lambda: HalfCheetah(cfg.seed)
22 | elif cfg.env_name == 'Walker2d':
23 | return lambda: Walker2d(cfg.seed)
24 | elif cfg.env_name == 'Hopper':
25 | return lambda: Hopper(cfg.seed)
26 | elif cfg.env_name == 'Ant':
27 | return lambda: Ant(cfg.seed)
28 | else:
29 | print(cfg.env_name)
30 | raise NotImplementedError
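
A minimal sketch of how `EnvFactory` above is used (mirroring `run_ac_offline.py`); `SimpleNamespace` stands in for the parsed command-line arguments, and the MuJoCo environments require the d4rl setup described in the README:
```
from types import SimpleNamespace
from core.environment.env_factory import EnvFactory

cfg = SimpleNamespace(env_name='Hopper', seed=0)   # only the fields the factory reads
env_fn = EnvFactory.create_env_fn(cfg)
env = env_fn()
state = env.reset()
next_state, reward, done, info = env.step([env.env.action_space.sample()])
print(state.shape, reward, done)
```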
--------------------------------------------------------------------------------
/core/environment/halfcheetah.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
3 |
4 | import gym
5 | import d4rl
6 | import numpy as np
7 |
8 | from core.utils.torch_utils import random_seed
9 |
10 |
11 | class HalfCheetah:
12 | def __init__(self, seed=np.random.randint(int(1e5))):
13 | random_seed(seed)
14 | self.state_dim = (17,)
15 | self.action_dim = 6
16 | # self.env = gym.make('HalfCheetah-v2')
17 |         self.env = gym.make('halfcheetah-random-v2')  # Load the d4rl env for the convenience of getting the normalized score from d4rl
18 | self.env.unwrapped.seed(seed)
19 | self.env._max_episode_steps = np.inf # control timeout setting in agent
20 | self.state = None
21 |
22 | def reset(self):
23 | return self.env.reset()
24 |
25 | def step(self, a):
26 | ret = self.env.step(a[0])
27 | state, reward, done, info = ret
28 | self.state = state
29 | # self.env.render()
30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
31 |
32 | def get_visualization_segment(self):
33 | raise NotImplementedError
34 |
35 | def get_useful(self, state=None):
36 | if state:
37 | return state
38 | else:
39 | return np.array(self.env.state)
40 |
41 | def info(self, key):
42 | return
43 |
--------------------------------------------------------------------------------
/core/environment/hopper.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
3 |
4 | import gym
5 | import d4rl
6 | import numpy as np
7 |
8 | from core.utils.torch_utils import random_seed
9 |
10 |
11 | class Hopper:
12 | def __init__(self, seed=np.random.randint(int(1e5))):
13 | random_seed(seed)
14 | self.state_dim = (11,)
15 | self.action_dim = 3
16 | # self.env = gym.make('Hopper-v2')
17 |         self.env = gym.make('hopper-random-v2')  # Load the d4rl env for the convenience of getting the normalized score from d4rl
18 | self.env.unwrapped.seed(seed)
19 | self.env._max_episode_steps = np.inf # control timeout setting in agent
20 | self.state = None
21 |
22 | def reset(self):
23 | return self.env.reset()
24 |
25 | def step(self, a):
26 | ret = self.env.step(a[0])
27 | state, reward, done, info = ret
28 | self.state = state
29 | # self.env.env.render()
30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
31 |
32 | def get_visualization_segment(self):
33 | raise NotImplementedError
34 |
35 | def get_useful(self, state=None):
36 | if state:
37 | return state
38 | else:
39 | return np.array(self.env.state)
40 |
41 | def info(self, key):
42 | return
43 |
--------------------------------------------------------------------------------
/core/environment/lunarlander.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | import copy
4 |
5 | from core.utils.torch_utils import random_seed
6 |
7 |
8 | class LunarLander:
9 | def __init__(self, seed=np.random.randint(int(1e5))):
10 | random_seed(seed)
11 | self.state_dim = (8,)
12 | self.action_dim = 4
13 | self.env = gym.make('LunarLander-v2')
14 | self.env._seed = seed
15 | self.env._max_episode_steps = np.inf # control timeout setting in agent
16 |
17 | def generate_state(self, coords):
18 | return coords
19 |
20 | def reset(self):
21 | return np.asarray(self.env.reset())
22 |
23 | def step(self, a):
24 | state, reward, done, info = self.env.step(a[0])
25 | # self.env.render()
26 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
27 |
28 | def get_visualization_segment(self):
29 | raise NotImplementedError
30 |
31 | def get_useful(self, state=None):
32 | if state:
33 | return state
34 | else:
35 | return np.array(self.env.state)
36 |
37 | def info(self, key):
38 | return
39 |
--------------------------------------------------------------------------------
/core/environment/mountaincar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import gym
3 | import copy
4 |
5 | from core.utils.torch_utils import random_seed
6 |
7 |
8 | class MountainCar:
9 | def __init__(self, seed=np.random.randint(int(1e5))):
10 | random_seed(seed)
11 | self.state_dim = (2,)
12 | self.action_dim = 3
13 | self.env = gym.make('MountainCar-v0')
14 | self.env._seed = seed
15 | self.env._max_episode_steps = np.inf # control timeout setting in agent
16 |
17 | def generate_state(self, coords):
18 | return coords
19 |
20 | def reset(self):
21 | return np.asarray(self.env.reset())
22 |
23 | def step(self, a):
24 | state, reward, done, info = self.env.step(a[0])
25 | # self.env.render()
26 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
27 |
28 | def get_visualization_segment(self):
29 | raise NotImplementedError
30 |
31 | def get_useful(self, state=None):
32 | if state:
33 | return state
34 | else:
35 | return np.array(self.env.state)
36 |
37 | def info(self, key):
38 | return
39 |
40 |
--------------------------------------------------------------------------------
/core/environment/walker2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
3 |
4 | import gym
5 | import d4rl
6 | import numpy as np
7 |
8 | from core.utils.torch_utils import random_seed
9 |
10 |
11 | class Walker2d:
12 | def __init__(self, seed=np.random.randint(int(1e5))):
13 | random_seed(seed)
14 | self.state_dim = (17,)
15 | self.action_dim = 6
16 | # self.env = gym.make('Walker2d-v2')
17 |         self.env = gym.make('walker2d-random-v2')  # Load the d4rl env for the convenience of getting the normalized score from d4rl
18 | self.env.unwrapped.seed(seed)
19 | self.env._max_episode_steps = np.inf # control timeout setting in agent
20 | self.state = None
21 |
22 | def reset(self):
23 | return self.env.reset()
24 |
25 | def step(self, a):
26 | ret = self.env.step(a[0])
27 | state, reward, done, info = ret
28 | self.state = state
29 | # self.env.render()
30 | return np.asarray(state), np.asarray(reward), np.asarray(done), info
31 |
32 | def get_visualization_segment(self):
33 | raise NotImplementedError
34 |
35 | def get_useful(self, state=None):
36 | if state:
37 | return state
38 | else:
39 | return np.array(self.env.state)
40 |
41 | def info(self, key):
42 | return
43 |
--------------------------------------------------------------------------------
/core/network/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/network/__init__.py
--------------------------------------------------------------------------------
/core/network/network_architectures.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as functional
5 |
6 | from core.network import network_utils, network_bodies
7 | from core.utils import torch_utils
8 |
9 |
10 | class FCNetwork(nn.Module):
11 | def __init__(self, device, input_units, hidden_units, output_units, head_activation=lambda x:x):
12 | super().__init__()
13 | body = network_bodies.FCBody(device, input_units, hidden_units=tuple(hidden_units), init_type='xavier')
14 | self.body = body
15 | self.fc_head = network_utils.layer_init_xavier(nn.Linear(body.feature_dim, output_units, bias=True), bias=True)
16 | self.device = device
17 | self.head_activation = head_activation
18 | self.to(device)
19 |
20 | def forward(self, x):
21 | if not isinstance(x, torch.Tensor): x = torch_utils.tensor(x, self.device)
22 | if len(x.shape) > 2: x = x.view(x.shape[0], -1)
23 | y = self.body(x)
24 | y = self.fc_head(y)
25 | y = self.head_activation(y)
26 | return y
27 |
28 | class DoubleCriticDiscrete(nn.Module):
29 | def __init__(self, device, input_units, hidden_units, output_units):
30 | super().__init__()
31 | self.device = device
32 | self.q1_net = FCNetwork(device, input_units, hidden_units, output_units)
33 | self.q2_net = FCNetwork(device, input_units, hidden_units, output_units)
34 |
35 | # def forward(self, x, a):
36 | def forward(self, x):
37 | if not isinstance(x, torch.Tensor): x = torch_utils.tensor(x, self.device)
38 | recover_size = False
39 | if len(x.size()) == 1:
40 | recover_size = True
41 | x = x.reshape((1, -1))
42 | q1 = self.q1_net(x)
43 | q2 = self.q2_net(x)
44 | if recover_size:
45 | q1 = q1[0]
46 | q2 = q2[0]
47 | return q1, q2
48 |
49 |
50 | class DoubleCriticNetwork(nn.Module):
51 | def __init__(self, device, num_inputs, num_actions, hidden_units):
52 | super(DoubleCriticNetwork, self).__init__()
53 | self.device = device
54 |
55 | # Q1 architecture
56 | self.body1 = network_bodies.FCBody(device, num_inputs + num_actions, hidden_units=tuple(hidden_units))
57 | self.head1 = network_utils.layer_init_xavier(nn.Linear(self.body1.feature_dim, 1))
58 | # Q2 architecture
59 | self.body2 = network_bodies.FCBody(device, num_inputs + num_actions, hidden_units=tuple(hidden_units))
60 | self.head2 = network_utils.layer_init_xavier(nn.Linear(self.body2.feature_dim, 1))
61 |
62 | def forward(self, state, action):
63 | if not isinstance(state, torch.Tensor): state = torch_utils.tensor(state, self.device)
64 | recover_size = False
65 | if len(state.shape) > 2:
66 | state = state.view(state.shape[0], -1)
67 | action = action.view(action.shape[0], -1)
68 | elif len(state.shape) == 1:
69 | state = state.view(1, -1)
70 | action = action.view(1, -1)
71 | recover_size = True
72 | if not isinstance(action, torch.Tensor): action = torch_utils.tensor(action, self.device)
73 |
74 | xu = torch.cat([state, action], 1)
75 |
76 | q1 = self.head1(self.body1(xu))
77 | q2 = self.head2(self.body2(xu))
78 |
79 | if recover_size:
80 | q1 = q1[0]
81 | q2 = q2[0]
82 | return q1, q2
83 |
84 |
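
A minimal usage sketch of `DoubleCriticNetwork` above for continuous control, with shapes matching the HalfCheetah settings in the README; the element-wise minimum of the two heads is the clipped double-Q value the agent consumes:
```
import torch
from core.network.network_architectures import DoubleCriticNetwork

q1q2 = DoubleCriticNetwork(device='cpu', num_inputs=17, num_actions=6, hidden_units=[256, 256])
states, actions = torch.zeros(8, 17), torch.zeros(8, 6)
q1, q2 = q1q2(states, actions)            # two independent Q estimates, each of shape (8, 1)
min_q = torch.min(q1, q2).squeeze(-1)     # clipped double-Q value, shape (8,)
print(q1.shape, min_q.shape)
```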
--------------------------------------------------------------------------------
/core/network/network_bodies.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as functional
6 |
7 | from core.network import network_utils
8 |
9 | class FCBody(nn.Module):
10 | def __init__(self, device, input_dim, hidden_units=(64, 64), activation=functional.relu, init_type='xavier', info=None):
11 | super().__init__()
12 | self.to(device)
13 | self.device = device
14 | dims = (input_dim,) + hidden_units
15 |         # self.layers is constructed below according to init_type
16 |
17 | if init_type == "xavier":
18 | self.layers = nn.ModuleList([network_utils.layer_init_xavier(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
19 | elif init_type == "uniform":
20 | self.layers = nn.ModuleList([network_utils.layer_init_uniform(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
21 | elif init_type == "zeros":
22 | self.layers = nn.ModuleList([network_utils.layer_init_zero(nn.Linear(dim_in, dim_out).to(device)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
23 | elif init_type == "constant":
24 | self.layers = nn.ModuleList([network_utils.layer_init_constant(nn.Linear(dim_in, dim_out).to(device), const=info) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
25 | else:
26 | raise ValueError('init_type is not defined: {}'.format(init_type))
27 |
28 | self.activation = activation
29 | self.feature_dim = dims[-1]
30 |
31 | def forward(self, x):
32 | for layer in self.layers:
33 | x = self.activation(layer(x))
34 | return x
35 |
36 | def compute_lipschitz_upper(self):
37 | return [np.linalg.norm(layer.weight.detach().cpu().numpy(), ord=2) for layer in self.layers]
38 |
39 |
40 | class ConvBody(nn.Module):
41 | def __init__(self, device, state_dim, architecture):
42 | super().__init__()
43 |
44 | def size(size, kernel_size=3, stride=1, padding=0):
45 | return (size + 2 * padding - (kernel_size - 1) - 1) // stride + 1
46 |
47 | spatial_length, _, in_channels = state_dim
48 | num_units = None
49 | layers = nn.ModuleList()
50 | for layer_cfg in architecture['conv_layers']:
51 | layers.append(nn.Conv2d(layer_cfg["in"], layer_cfg["out"], layer_cfg["kernel"],
52 | layer_cfg["stride"], layer_cfg["pad"]))
53 | if not num_units:
54 | num_units = size(spatial_length, layer_cfg["kernel"], layer_cfg["stride"], layer_cfg["pad"])
55 | else:
56 | num_units = size(num_units, layer_cfg["kernel"], layer_cfg["stride"], layer_cfg["pad"])
57 | num_units = num_units ** 2 * architecture["conv_layers"][-1]["out"]
58 |
59 | self.feature_dim = num_units
60 | self.spatial_length = spatial_length
61 | self.in_channels = in_channels
62 | self.layers = layers
63 | self.to(device)
64 | self.device = device
65 |
66 | def forward(self, x):
67 | x = functional.relu(self.layers[0](self.shape_image(x)))
68 | for idx, layer in enumerate(self.layers[1:]):
69 | x = functional.relu(layer(x))
70 | # return x.view(x.size(0), -1)
71 | return x.reshape(x.size(0), -1)
72 |
--------------------------------------------------------------------------------
/core/network/network_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def layer_init(layer, w_scale=1.0):
5 | nn.init.orthogonal_(layer.weight.data)
6 | layer.weight.data.mul_(w_scale)
7 | nn.init.constant_(layer.bias.data, 0)
8 | return layer
9 |
10 |
11 | def layer_init_zero(layer, bias=True):
12 | nn.init.constant_(layer.weight, 0)
13 | if bias:
14 | nn.init.constant_(layer.bias.data, 0)
15 | return layer
16 |
17 | def layer_init_constant(layer, const, bias=True):
18 | nn.init.constant_(layer.weight, const)
19 | if bias:
20 | nn.init.constant_(layer.bias.data, const)
21 | return layer
22 |
23 |
24 | def layer_init_xavier(layer, bias=True):
25 | nn.init.xavier_uniform_(layer.weight)
26 | if bias:
27 | nn.init.constant_(layer.bias.data, 0)
28 | return layer
29 |
30 | def layer_init_uniform(layer, low=-0.003, high=0.003, bias=0):
31 | nn.init.uniform_(layer.weight, low, high)
32 | if not (type(bias)==bool and bias==False):
33 | nn.init.constant_(layer.bias.data, bias)
34 | return layer
35 |
--------------------------------------------------------------------------------
/core/network/policy_factory.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.distributions import Normal
6 | from torch.distributions import Categorical
7 |
8 | from core.network import network_utils, network_bodies
9 | from core.utils import torch_utils
10 |
11 |
12 | class MLPCont(nn.Module):
13 | def __init__(self, device, obs_dim, act_dim, hidden_sizes, action_range=1.0, init_type='xavier'):
14 | super().__init__()
15 | self.device = device
16 | body = network_bodies.FCBody(device, obs_dim, hidden_units=tuple(hidden_sizes), init_type=init_type)
17 | body_out = obs_dim if hidden_sizes==[] else hidden_sizes[-1]
18 | self.body = body
19 | self.mu_layer = network_utils.layer_init_xavier(nn.Linear(body_out, act_dim))
20 | self.log_std_logits = nn.Parameter(torch.zeros(act_dim, requires_grad=True))
21 | self.min_log_std = -6
22 | self.max_log_std = 0
23 | self.action_range = action_range
24 |
25 |     # Adapted from https://github.com/hari-sikchi/AWAC/blob/3ad931ec73101798ffe82c62b19313a8607e4f1e/core.py#L91
26 | def forward(self, obs, deterministic=False):
27 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device)
28 | recover_size = False
29 | if len(obs.size()) == 1:
30 | recover_size = True
31 | obs = obs.reshape((1, -1))
32 | net_out = self.body(obs)
33 | mu = self.mu_layer(net_out)
34 | mu = torch.tanh(mu) * self.action_range
35 |
36 | log_std = torch.sigmoid(self.log_std_logits)
37 | log_std = self.min_log_std + log_std * (self.max_log_std - self.min_log_std)
38 | std = torch.exp(log_std)
39 | pi_distribution = Normal(mu, std)
40 | if deterministic:
41 | pi_action = mu
42 | else:
43 | pi_action = pi_distribution.rsample()
44 | logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
45 |
46 | if recover_size:
47 | pi_action, logp_pi = pi_action[0], logp_pi[0]
48 | return pi_action, logp_pi
49 |
50 | def get_logprob(self, obs, actions):
51 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device)
52 | if not isinstance(actions, torch.Tensor): actions = torch_utils.tensor(actions, self.device)
53 | net_out = self.body(obs)
54 | mu = self.mu_layer(net_out)
55 | mu = torch.tanh(mu) * self.action_range
56 | log_std = torch.sigmoid(self.log_std_logits)
57 | # log_std = self.log_std_layer(net_out)
58 | log_std = self.min_log_std + log_std * (
59 | self.max_log_std - self.min_log_std)
60 | std = torch.exp(log_std)
61 | pi_distribution = Normal(mu, std)
62 | logp_pi = pi_distribution.log_prob(actions).sum(axis=-1)
63 | return logp_pi
64 |
65 |
66 | class MLPDiscrete(nn.Module):
67 | def __init__(self, device, obs_dim, act_dim, hidden_sizes, init_type='xavier'):
68 | super().__init__()
69 | self.device = device
70 | body = network_bodies.FCBody(device, obs_dim, hidden_units=tuple(hidden_sizes), init_type=init_type)
71 | body_out = obs_dim if hidden_sizes==[] else hidden_sizes[-1]
72 | self.body = body
73 | self.mu_layer = network_utils.layer_init_xavier(nn.Linear(body_out, act_dim))
74 | self.log_std_logits = nn.Parameter(torch.zeros(act_dim, requires_grad=True))
75 | self.min_log_std = -6
76 | self.max_log_std = 0
77 |
78 |     def forward(self, obs, deterministic=True):  # deterministic is accepted for API parity; the action is always sampled
79 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device)
80 | recover_size = False
81 | if len(obs.size()) == 1:
82 | recover_size = True
83 | obs = obs.reshape((1, -1))
84 | net_out = self.body(obs)
85 | probs = self.mu_layer(net_out)
86 | probs = F.softmax(probs, dim=1)
87 | m = Categorical(probs)
88 | action = m.sample()
89 | logp = m.log_prob(action)
90 | if recover_size:
91 | action, logp = action[0], logp[0]
92 | return action, logp
93 |
94 | def get_logprob(self, obs, actions):
95 | if not isinstance(obs, torch.Tensor): obs = torch_utils.tensor(obs, self.device)
96 | if not isinstance(actions, torch.Tensor): actions = torch_utils.tensor(actions, self.device)
97 | net_out = self.body(obs)
98 | probs = self.mu_layer(net_out)
99 | probs = F.softmax(probs, dim=1)
100 | m = Categorical(probs)
101 | logp_pi = m.log_prob(actions)
102 | return logp_pi
103 |
--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/core/utils/__init__.py
--------------------------------------------------------------------------------
/core/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def common_member(a, b):
5 | a_set = set(a)
6 | b_set = set(b)
7 | if (a_set & b_set):
8 | return True
9 | else:
10 | return False
11 |
12 | def arcradians(cos, sin):
13 | if cos > 0 and sin > 0:
14 | return np.arccos(cos)
15 | elif cos > 0 and sin < 0:
16 | return np.arcsin(sin)
17 | elif cos < 0 and sin > 0:
18 | return np.arccos(cos)
19 | elif cos < 0 and sin < 0:
20 | return -1 * np.arccos(cos)
21 |
22 |
23 | def normalize_rows(x):
24 | return x / np.linalg.norm(x, ord=2, axis=1, keepdims=True)
25 |
26 | def copy_row(x, num_rows):
27 | return np.multiply(np.ones((num_rows, 1)), x)
28 |
29 | def expectile_loss(diff, expectile=0.8):
30 | weight = torch.where(diff > 0, expectile, (1 - expectile))
31 | return weight * (diff ** 2)
32 |
33 | def search_same_row(matrix, target_row):
34 | idx = np.where(np.all(matrix == target_row, axis=1))
35 | return idx
--------------------------------------------------------------------------------
/core/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import logging
5 |
6 | # from tensorboardX import SummaryWriter
7 |
8 | def log_config(cfg):
9 | def get_print_attrs(cfg):
10 | attrs = dict(cfg.__dict__)
11 | for k in ['logger', 'env_fn', 'offline_data']:
12 | del attrs[k]
13 | return attrs
14 | attrs = get_print_attrs(cfg)
15 | for param, value in attrs.items():
16 | cfg.logger.info('{}: {}'.format(param, value))
17 |
18 |
19 | class Logger:
20 | def __init__(self, config, log_dir):
21 | log_file = os.path.join(log_dir, 'log')
22 | self._logger = logging.getLogger()
23 |
24 | file_handler = logging.FileHandler(log_file, mode='w')
25 | formatter = logging.Formatter('%(asctime)s | %(message)s')
26 | file_handler.setFormatter(formatter)
27 | self._logger.addHandler(file_handler)
28 |
29 | stream_handler = logging.StreamHandler(sys.stdout)
30 | stream_handler.setFormatter(formatter)
31 | self._logger.addHandler(stream_handler)
32 |
33 | self._logger.setLevel(level=logging.INFO)
34 |
35 | self.config = config
36 | # if config.tensorboard_logs: self.tensorboard_writer = SummaryWriter(config.get_log_dir())
37 |
38 | def info(self, log_msg):
39 | self._logger.info(log_msg)
--------------------------------------------------------------------------------
/core/utils/run_funcs.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import time
3 | import copy
4 | import numpy as np
5 |
6 | import os
7 | os.environ['D4RL_SUPPRESS_IMPORT_ERROR'] = '1'
8 | import gym
9 | import d4rl
10 | import gzip
11 |
12 | EARLYCUTOFF = "EarlyCutOff"
13 |
14 |
15 | def load_testset(env_name, dataset, id):
16 | path = None
17 | if env_name == 'HalfCheetah':
18 | if dataset == 'expert':
19 | path = {"env": "halfcheetah-expert-v2"}
20 | elif dataset == 'medexp':
21 | path = {"env": "halfcheetah-medium-expert-v2"}
22 | elif dataset == 'medium':
23 | path = {"env": "halfcheetah-medium-v2"}
24 | elif dataset == 'medrep':
25 | path = {"env": "halfcheetah-medium-replay-v2"}
26 | elif env_name == 'Walker2d':
27 | if dataset == 'expert':
28 | path = {"env": "walker2d-expert-v2"}
29 | elif dataset == 'medexp':
30 | path = {"env": "walker2d-medium-expert-v2"}
31 | elif dataset == 'medium':
32 | path = {"env": "walker2d-medium-v2"}
33 | elif dataset == 'medrep':
34 | path = {"env": "walker2d-medium-replay-v2"}
35 | elif env_name == 'Hopper':
36 | if dataset == 'expert':
37 | path = {"env": "hopper-expert-v2"}
38 | elif dataset == 'medexp':
39 | path = {"env": "hopper-medium-expert-v2"}
40 | elif dataset == 'medium':
41 | path = {"env": "hopper-medium-v2"}
42 | elif dataset == 'medrep':
43 | path = {"env": "hopper-medium-replay-v2"}
44 | elif env_name == 'Ant':
45 | if dataset == 'expert':
46 | path = {"env": "ant-expert-v2"}
47 | elif dataset == 'medexp':
48 | path = {"env": "ant-medium-expert-v2"}
49 | elif dataset == 'medium':
50 | path = {"env": "ant-medium-v2"}
51 | elif dataset == 'medrep':
52 | path = {"env": "ant-medium-replay-v2"}
53 |
54 | elif env_name == 'Acrobot':
55 | if dataset == 'expert':
56 | path = {"pkl": "data/dataset/acrobot/transitions_50k/train_40k/{}_run.pkl".format(id)}
57 | elif dataset == 'mixed':
58 | path = {"pkl": "data/dataset/acrobot/transitions_50k/train_mixed/{}_run.pkl".format(id)}
59 | elif env_name == 'LunarLander':
60 | if dataset == 'expert':
61 | path = {"pkl": "data/dataset/lunar_lander/transitions_50k/train_500k/{}_run.pkl".format(id)}
62 | elif dataset == 'mixed':
63 | path = {"pkl": "data/dataset/lunar_lander/transitions_50k/train_mixed/{}_run.pkl".format(id)}
64 | elif env_name == 'MountainCar':
65 | if dataset == 'expert':
66 | path = {"pkl": "data/dataset/mountain_car/transitions_50k/train_60k/{}_run.pkl".format(id)}
67 | elif dataset == 'mixed':
68 | path = {"pkl": "data/dataset/mountain_car/transitions_50k/train_mixed/{}_run.pkl".format(id)}
69 |
70 | assert path is not None
71 | testsets = {}
72 | for name in path:
73 | if name == "env":
74 | env = gym.make(path['env'])
75 | try:
76 | data = env.get_dataset()
77 |             except Exception:
78 | env = env.unwrapped
79 | data = env.get_dataset()
80 | testsets[name] = {
81 | 'states': data['observations'],
82 | 'actions': data['actions'],
83 | 'rewards': data['rewards'],
84 | 'next_states': data['next_observations'],
85 | 'terminations': data['terminals'],
86 | }
87 | else:
88 | pth = path[name]
89 | with open(pth.format(id), 'rb') as f:
90 | testsets[name] = pickle.load(f)
91 |
92 | return testsets
93 | else:
94 | return {}
95 |
96 | def run_steps(agent, max_steps, log_interval, eval_pth):
97 | t0 = time.time()
98 | evaluations = []
99 | agent.populate_returns(initialize=True)
100 | while True:
101 | if log_interval and not agent.total_steps % log_interval:
102 | mean, median, min_, max_ = agent.log_file(elapsed_time=log_interval / (time.time() - t0), test=True)
103 | evaluations.append(mean)
104 | t0 = time.time()
105 | if max_steps and agent.total_steps >= max_steps:
106 | break
107 | agent.step()
108 | agent.save()
109 | np.save(eval_pth+"/evaluations.npy", np.array(evaluations))
--------------------------------------------------------------------------------
/core/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def tensor(x, device):
7 | if isinstance(x, torch.Tensor):
8 | return x
9 | x = torch.tensor(x, dtype=torch.float32).to(device)
10 | return x
11 |
12 | def to_np(t):
13 | return t.cpu().detach().numpy()
14 |
15 | def random_seed(seed):
16 | np.random.seed(seed)
17 | torch.manual_seed(seed)
18 |
19 | def set_one_thread():
20 | os.environ['OMP_NUM_THREADS'] = '1'
21 | os.environ['MKL_NUM_THREADS'] = '1'
22 | torch.set_num_threads(1)
23 |
24 | def ensure_dir(d):
25 | if not os.path.exists(d):
26 | os.makedirs(d)
27 |
--------------------------------------------------------------------------------
/img/after_fix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hwang-ua/inac_pytorch/ca5007bbd59cf53adf0cc588dc5130b836c30622/img/after_fix.png
--------------------------------------------------------------------------------
/run_ac_offline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import core.environment.env_factory as environment
5 | from core.utils import torch_utils, logger, run_funcs
6 | from core.agent.in_sample import InSampleAC
7 |
8 |
9 | if __name__ == '__main__':
10 | parser = argparse.ArgumentParser(description="run_file")
11 | parser.add_argument('--seed', default=0, type=int)
12 | parser.add_argument('--env_name', default='Ant', type=str)
13 | parser.add_argument('--dataset', default='medexp', type=str)
14 | parser.add_argument('--discrete_control', default=0, type=int)
15 | parser.add_argument('--state_dim', default=1, type=int)
16 | parser.add_argument('--action_dim', default=1, type=int)
17 | parser.add_argument('--tau', default=0.1, type=float)
18 |
19 | parser.add_argument('--max_steps', default=1000000, type=int)
20 | parser.add_argument('--log_interval', default=10000, type=int)
21 | parser.add_argument('--learning_rate', default=3e-4, type=float)
22 | parser.add_argument('--hidden_units', default=256, type=int)
23 | parser.add_argument('--batch_size', default=256, type=int)
24 | parser.add_argument('--timeout', default=1000, type=int)
25 | parser.add_argument('--gamma', default=0.99, type=float)
26 | parser.add_argument('--use_target_network', default=1, type=int)
27 | parser.add_argument('--target_network_update_freq', default=1, type=int)
28 | parser.add_argument('--polyak', default=0.995, type=float)
29 | parser.add_argument('--evaluation_criteria', default='return', type=str)
30 | parser.add_argument('--device', default='cpu', type=str)
31 | parser.add_argument('--info', default='0', type=str)
32 | cfg = parser.parse_args()
33 |
34 | torch_utils.set_one_thread()
35 |
36 | torch_utils.random_seed(cfg.seed)
37 |
38 | project_root = os.path.abspath(os.path.dirname(__file__))
39 | exp_path = "data/output/{}/{}/{}/{}_run".format(cfg.env_name, cfg.dataset, cfg.info, cfg.seed)
40 | cfg.exp_path = os.path.join(project_root, exp_path)
41 | torch_utils.ensure_dir(cfg.exp_path)
42 | cfg.env_fn = environment.EnvFactory.create_env_fn(cfg)
43 | cfg.offline_data = run_funcs.load_testset(cfg.env_name, cfg.dataset, cfg.seed)
44 |
45 | # Setting up the logger
46 | cfg.logger = logger.Logger(cfg, cfg.exp_path)
47 | logger.log_config(cfg)
48 |
49 | # Initializing the agent and running the experiment
50 | agent_obj = InSampleAC(
51 | device=cfg.device,
52 | discrete_control=cfg.discrete_control,
53 | state_dim=cfg.state_dim,
54 | action_dim=cfg.action_dim,
55 | hidden_units=cfg.hidden_units,
56 | learning_rate=cfg.learning_rate,
57 | tau=cfg.tau,
58 | polyak=cfg.polyak,
59 | exp_path=cfg.exp_path,
60 | seed=cfg.seed,
61 | env_fn=cfg.env_fn,
62 | timeout=cfg.timeout,
63 | gamma=cfg.gamma,
64 | offline_data=cfg.offline_data,
65 | batch_size=cfg.batch_size,
66 | use_target_network=cfg.use_target_network,
67 | target_network_update_freq=cfg.target_network_update_freq,
68 | evaluation_criteria=cfg.evaluation_criteria,
69 | logger=cfg.logger
70 | )
71 |     run_funcs.run_steps(agent_obj, cfg.max_steps, cfg.log_interval, cfg.exp_path)
--------------------------------------------------------------------------------