├── real-anymal.png
├── anymal-beliefs.png
├── anymal-teacher-student.png
├── anymal_belief_state_encoder_decoder_pytorch
│   ├── __init__.py
│   ├── running.py
│   ├── trainer.py
│   ├── ppo.py
│   └── networks.py
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
└── README.md
/real-anymal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch/HEAD/real-anymal.png
--------------------------------------------------------------------------------
/anymal-beliefs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch/HEAD/anymal-beliefs.png
--------------------------------------------------------------------------------
/anymal-teacher-student.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch/HEAD/anymal-teacher-student.png
--------------------------------------------------------------------------------
/anymal_belief_state_encoder_decoder_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from anymal_belief_state_encoder_decoder_pytorch.networks import Student, Teacher, MLP, Anymal
2 | from anymal_belief_state_encoder_decoder_pytorch.ppo import PPO, MockEnv
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'anymal-belief-state-encoder-decoder-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.0.21',
7 | license='MIT',
8 | description = 'Anymal Belief-state Encoder Decoder - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch',
12 | keywords = [
13 | 'artificial intelligence',
14 | 'deep learning',
15 | 'attention gating',
16 | 'belief state',
17 | 'robotics'
18 | ],
19 | install_requires=[
20 | 'assoc-scan>=0.0.2',
21 | 'einops>=0.8',
22 | 'evolutionary-policy-optimization>=0.0.61',
23 | 'torch>=2.2',
24 | ],
25 | classifiers=[
26 | 'Development Status :: 4 - Beta',
27 | 'Intended Audience :: Developers',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
29 | 'License :: OSI Approved :: MIT License',
30 | 'Programming Language :: Python :: 3.6',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/anymal_belief_state_encoder_decoder_pytorch/running.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class RunningStats(nn.Module):
5 | def __init__(self, shape, eps = 1e-5):
6 | super().__init__()
7 | shape = shape if isinstance(shape, tuple) else (shape,)
8 |
9 | self.shape = shape
10 | self.eps = eps
11 | self.n = 0
12 |
13 | self.register_buffer('old_mean', torch.zeros(shape), persistent = False)
14 | self.register_buffer('new_mean', torch.zeros(shape), persistent = False)
15 | self.register_buffer('old_std', torch.zeros(shape), persistent = False)
16 | self.register_buffer('new_std', torch.zeros(shape), persistent = False)
17 |
18 | def clear(self):
19 | self.n = 0
20 |
21 | def push(self, x):
22 | self.n += 1
23 |
24 | if self.n == 1:
25 | self.old_mean.copy_(x.data)
26 | self.new_mean.copy_(x.data)
27 | self.old_std.zero_()
28 | self.new_std.zero_()
29 | return
30 |
31 | self.new_mean.copy_(self.old_mean + (x - self.old_mean) / self.n)
32 | self.new_std.copy_(self.old_std + (x - self.old_mean) * (x - self.new_mean))
33 |
34 | self.old_mean.copy_(self.new_mean)
35 | self.old_std.copy_(self.new_std)
36 |
37 | def mean(self):
38 | return self.new_mean if self.n else torch.zeros_like(self.new_mean)
39 |
40 | def variance(self):
41 | return (self.new_std / (self.n - 1)) if self.n > 1 else torch.zeros_like(self.new_std)
42 |
43 | def rstd(self):
44 | return torch.rsqrt(self.variance() + self.eps)
45 |
46 | def norm(self, x):
47 | return (x - self.mean()) * self.rstd()
48 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/anymal_belief_state_encoder_decoder_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.utils.data import Dataset, DataLoader
4 | from torch.optim import Adam
5 |
6 | from collections import deque
7 | from einops import rearrange
8 |
9 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
10 |
11 | class ExperienceDataset(Dataset):
12 | def __init__(self, data):
13 | super().__init__()
14 | self.data = data
15 |
16 | def __len__(self):
17 | return len(self.data[0])
18 |
19 | def __getitem__(self, ind):
20 | return tuple(map(lambda t: t[ind], self.data))
21 |
22 | def create_dataloader(data, batch_size):
23 | ds = ExperienceDataset(data)
24 | return DataLoader(ds, batch_size = batch_size, drop_last = True)
25 |
26 | class StudentTrainer(nn.Module):
27 | def __init__(
28 | self,
29 | *,
30 | anymal,
31 | env,
32 | epochs = 2,
33 | lr = 5e-4,
34 | max_timesteps = 10000,
35 | update_timesteps = 5000,
36 | minibatch_size = 16,
37 |         truncate_tbptt = 10
38 | ):
39 | super().__init__()
40 | self.env = env
41 | self.anymal = anymal
42 | self.optimizer = Adam(anymal.student.parameters(), lr = lr)
43 | self.epochs = epochs
44 |
45 | self.max_timesteps = max_timesteps
46 | self.update_timesteps = update_timesteps
47 | self.minibatch_size = minibatch_size
48 |         self.truncate_tbptt = truncate_tbptt
49 |
50 | self.running_proprio, self.running_extero = anymal.get_observation_running_stats()
51 |
52 | def learn_from_memories(
53 | self,
54 | memories,
55 | next_states,
56 | noise_strength = 0.
57 | ):
58 | device = next(self.parameters()).device
59 |
60 | # retrieve and prepare data from memory for training
61 |
62 | states = []
63 | teacher_states = []
64 | hiddens = []
65 | dones = []
66 |
67 | for (state, teacher_state, hidden, done) in memories:
68 | states.append(state)
69 | teacher_states.append(teacher_state)
70 | hiddens.append(hidden)
71 | dones.append(torch.Tensor([done]))
72 |
73 | states = tuple(zip(*states))
74 | teacher_states = tuple(zip(*teacher_states))
75 |
76 | # convert values to torch tensors
77 |
78 | to_torch_tensor = lambda t: torch.stack(t).to(device).detach()
79 |
80 | states = map(to_torch_tensor, states)
81 | teacher_states = map(to_torch_tensor, teacher_states)
82 | hiddens = to_torch_tensor(hiddens)
83 | dones = to_torch_tensor(dones)
84 |
85 | # prepare dataloader for policy phase training
86 |
87 | dl = create_dataloader([*states, *teacher_states, hiddens, dones], self.minibatch_size)
88 |
89 | current_hiddens = self.anymal.student.get_gru_hiddens()
90 | current_hiddens = rearrange(current_hiddens, 'l d -> 1 l d')
91 |
92 | for _ in range(self.epochs):
93 | for ind, (proprio, extero, privileged, teacher_proprio, teacher_extero, episode_hiddens, done) in enumerate(dl):
94 |
95 | straight_through_hiddens = current_hiddens - current_hiddens.detach() + episode_hiddens
96 |
97 | loss, current_hiddens = self.anymal(
98 | proprio,
99 | extero,
100 | privileged,
101 | teacher_states = (teacher_proprio, teacher_extero),
102 | hiddens = straight_through_hiddens,
103 | noise_strength = noise_strength
104 | )
105 |
106 | loss.backward(retain_graph = True)
107 |
108 |                 tbptt_limit = not ((ind + 1) % self.truncate_tbptt)
109 | if tbptt_limit: # how far back in time should the gradients go for recurrence
110 | self.optimizer.step()
111 | self.optimizer.zero_grad()
112 | current_hiddens = current_hiddens.detach()
113 |
114 |                 # detach hiddens depending on whether it is a new episode or not
115 |                 # todo: restructure dataloader to load one episode per batch row
116 |
117 | maybe_detached_hiddens = []
118 | for current_hidden, done in zip(current_hiddens.unbind(dim = 0), dones.unbind(dim = 0)):
119 |                     maybe_detached_hiddens.append(current_hidden.detach() if done else current_hidden)
120 |
121 | current_hiddens = torch.stack(maybe_detached_hiddens)
122 |
123 | def forward(
124 | self,
125 | noise_strength = 0.
126 | ):
127 | device = next(self.parameters()).device
128 |
129 | time = 0
130 | done = False
131 | states = self.env.reset()
132 | memories = deque([])
133 |
134 | hidden = self.anymal.student.get_gru_hiddens()
135 | hidden = rearrange(hidden, 'l d -> 1 l d')
136 |
137 | self.running_proprio.clear()
138 | self.running_extero.clear()
139 |
140 | for timestep in range(self.max_timesteps):
141 | time += 1
142 |
143 | states = list(map(lambda t: t.to(device), states))
144 | anymal_states = list(map(lambda t: rearrange(t, '... -> 1 ...'), states))
145 |
146 | # teacher needs to have normalized observations
147 |
148 | (proprio, extero, privileged) = states
149 |
150 | self.running_proprio.push(proprio)
151 | self.running_extero.push(extero)
152 |
153 | teacher_states = (
154 | self.running_proprio.norm(proprio),
155 | self.running_extero.norm(extero)
156 | )
157 |
158 | teacher_anymal_states = list(map(lambda t: rearrange(t, '... -> 1 ...'), teacher_states))
159 |
160 | # add states to memories
161 |
162 | memories.append((
163 | states,
164 | teacher_states,
165 | rearrange(hidden, '1 ... -> ...'),
166 | done
167 | ))
168 |
169 | dist, hidden = self.anymal.forward_student(
170 | *anymal_states[:-1],
171 | hiddens = hidden,
172 | return_action_categorical_dist = True
173 | )
174 |
175 | action = dist.sample()
176 | action_log_prob = dist.log_prob(action)
177 | action = action.item()
178 |
179 | next_states, _, done, _ = self.env.step(action)
180 |
181 | states = next_states
182 |
183 | if time % self.update_timesteps == 0:
184 | self.learn_from_memories(memories, next_states, noise_strength = noise_strength)
185 | memories.clear()
186 |
187 | if done:
188 | break
189 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Belief State Encoder / Decoder (Anymal) - Pytorch
4 |
5 | Implementation of the Belief State Encoder / Decoder in the new breakthrough robotics paper from ETH Zürich, [Learning robust perceptive locomotion for quadrupedal robots in the wild](http://dx.doi.org/10.1126/scirobotics.abk2822).
6 |
7 | This paper is important as it seems their learned approach produced a policy that rivals Boston Dynamics' handcrafted algorithms (quadrupedal Spot).
8 |
9 | The results speak for themselves in their video demonstration.
10 |
11 | ## Install
12 |
13 | ```bash
14 | $ pip install anymal-belief-state-encoder-decoder-pytorch
15 | ```
16 |
17 | ## Usage
18 |
19 | Teacher
20 |
21 | ```python
22 | import torch
23 | from anymal_belief_state_encoder_decoder_pytorch import Teacher
24 |
25 | teacher = Teacher(
26 | num_actions = 10,
27 | num_legs = 4,
28 | extero_dim = 52,
29 | proprio_dim = 133,
30 | privileged_dim = 50
31 | )
32 |
33 | proprio = torch.randn(1, 133)
34 | extero = torch.randn(1, 4, 52)
35 | privileged = torch.randn(1, 50)
36 |
37 | action_logits, values = teacher(proprio, extero, privileged, return_value_head = True) # (1, 10), (1,)
38 | ```
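
The teacher can also wrap its action logits in a `Categorical` distribution and hand back the value estimate, which is how the PPO code consumes it. A minimal sketch, reusing the tensors above:

```python
dist, values = teacher(
    proprio, extero, privileged,
    return_value_head = True,
    return_action_categorical_dist = True
)

action = dist.sample()                  # (1,)
action_log_prob = dist.log_prob(action) # (1,)
```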
39 |
40 | Student
41 |
42 | ```python
43 | import torch
44 | from anymal_belief_state_encoder_decoder_pytorch import Student
45 |
46 | student = Student(
47 | num_actions = 10,
48 | num_legs = 4,
49 | extero_dim = 52,
50 | proprio_dim = 133,
51 | gru_num_layers = 2,
52 | gru_hidden_size = 50
53 | )
54 |
55 | proprio = torch.randn(1, 133)
56 | extero = torch.randn(1, 4, 52)
57 |
58 | action_logits, hiddens = student(proprio, extero) # (1, 10), (1, 2, 50)
59 | action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (1, 2, 50)
60 | action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (1, 2, 50)
61 |
62 | # hiddens are in the shape (batch size, num gru layers, gru hidden dimension)
63 | # train with truncated bptt
64 | ```
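
The `train with truncated bptt` note above can be made concrete. Below is a minimal sketch (not part of the library) that reuses the `student` above, carrying the hiddens forward and periodically detaching them; the objective is just a placeholder for the real distillation loss:

```python
from torch.optim import Adam

optimizer = Adam(student.parameters(), lr = 5e-4)

hiddens = None

for step in range(100):
    proprio = torch.randn(1, 133)
    extero = torch.randn(1, 4, 52)

    action_logits, hiddens = student(proprio, extero, hiddens)

    loss = action_logits.sum() # placeholder loss, only to show the mechanics
    loss.backward(retain_graph = True)

    if (step + 1) % 10 == 0:
        optimizer.step()
        optimizer.zero_grad()
        hiddens = hiddens.detach() # cut the graph so gradients reach back at most 10 steps
```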
65 |
66 | Full Anymal (which contains both Teacher and Student)
67 |
68 | ```python
69 | import torch
70 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
71 |
72 | anymal = Anymal(
73 | num_actions = 10,
74 | num_legs = 4,
75 | extero_dim = 52,
76 | proprio_dim = 133,
77 | privileged_dim = 50,
78 | recon_loss_weight = 0.5
79 | )
80 |
81 | # mock data
82 |
83 | proprio = torch.randn(1, 133)
84 | extero = torch.randn(1, 4, 52)
85 | privileged = torch.randn(1, 50)
86 |
87 | # first train teacher
88 |
89 | teacher_action_logits = anymal.forward_teacher(proprio, extero, privileged)
90 |
91 | # teacher is trained with privileged information in simulation with domain randomization
92 |
93 | # after the teacher has satisfactory performance, init the student with the teacher weights, excluding the privileged information encoder from the teacher (which the student does not have)
94 |
95 | anymal.init_student_with_teacher()
96 |
97 | # then train the student on the proprioception and noised exteroception, forcing it to reconstruct the privileged information that the teacher had access to (as well as learning to denoise the exteroception) - there is also a behavior loss between the policy logits of the teacher and those of the student
98 |
99 | loss, hiddens = anymal(proprio, extero, privileged)
100 | loss.backward()
101 |
102 | # finally, you can deploy the student to the real world, zero-shot
103 |
104 | anymal.eval()
105 | dist, hiddens = anymal.forward_student(proprio, extero, return_action_categorical_dist = True)
106 | action = dist.sample()
107 | ```
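
In the paper the observations fed to the teacher are normalized with running statistics (the `PPO` and `StudentTrainer` classes do this internally). The wrapper exposes the same `RunningStats` instances if you want to do it by hand - a small sketch, pushing a single mock observation just to show the mechanics:

```python
running_proprio, running_extero = anymal.get_observation_running_stats()

running_proprio.push(proprio[0])
running_extero.push(extero[0])

teacher_states = (
    running_proprio.norm(proprio),
    running_extero.norm(extero)
)

loss, hiddens = anymal(
    proprio, extero, privileged,
    teacher_states = teacher_states,
    noise_strength = 0.1
)
```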
108 |
109 | PPO training of the Teacher (using a mock environment; this needs to be substituted with an environment wrapper around a simulator)
110 |
111 | ```python
112 | import torch
113 | from anymal_belief_state_encoder_decoder_pytorch import Anymal, PPO
114 | from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv
115 |
116 | anymal = Anymal(
117 | num_actions = 10,
118 | num_legs = 4,
119 | extero_dim = 52,
120 | proprio_dim = 133,
121 | privileged_dim = 50,
122 | recon_loss_weight = 0.5
123 | )
124 |
125 | mock_env = MockEnv(
126 | proprio_dim = 133,
127 | extero_dim = 52,
128 | privileged_dim = 50
129 | )
130 |
131 | ppo = PPO(
132 | env = mock_env,
133 | anymal = anymal,
134 | epochs = 10,
135 | lr = 3e-4,
136 | eps_clip = 0.2,
137 | beta_s = 0.01,
138 | value_clip = 0.4,
139 | max_timesteps = 10000,
140 | update_timesteps = 5000,
141 | )
142 |
143 | # train for 10 episodes
144 |
145 | for _ in range(10):
146 | ppo()
147 |
148 | # save the weights of the teacher for student training
149 |
150 | torch.save(anymal.state_dict(), './anymal-with-trained-teacher.pt')
151 | ```
152 |
153 | To train the student
154 |
155 | ```python
156 | import torch
157 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
158 | from anymal_belief_state_encoder_decoder_pytorch.trainer import StudentTrainer
159 | from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv
160 |
161 | anymal = Anymal(
162 | num_actions = 10,
163 | num_legs = 4,
164 | extero_dim = 52,
165 | proprio_dim = 133,
166 | privileged_dim = 50,
167 | recon_loss_weight = 0.5
168 | )
169 |
170 | # first init the student with the teacher weights, at the very beginning
171 | # if not resuming training (see the snippet after this example)
172 |
173 | mock_env = MockEnv(
174 | proprio_dim = 133,
175 | extero_dim = 52,
176 | privileged_dim = 50
177 | )
178 |
179 | trainer = StudentTrainer(
180 | anymal = anymal,
181 | env = mock_env
182 | )
183 |
184 | # for 100 episodes
185 |
186 | for _ in range(100):
187 | trainer()
188 |
189 | ```
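
If you are resuming from the teacher trained in the PPO example above, a minimal sketch of wiring the two stages together might look like the following (the checkpoint path matches the one saved earlier, and the linearly increasing `noise_strength` is only a rough stand-in for the curriculum factor listed in the todo below):

```python
# load the anymal whose teacher was trained with PPO

anymal.load_state_dict(torch.load('./anymal-with-trained-teacher.pt'))

# copy the shared weights from the teacher into the student

anymal.init_student_with_teacher()

# train the student, ramping up the exteroception noise over the episodes

for epoch in range(100):
    trainer(noise_strength = epoch / 100)
```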
190 |
191 | ... You've beaten Boston Dynamics and its team of highly paid control engineers!
192 |
193 | But you probably haven't beaten a real quadrupedal "anymal" just yet :)
194 |
195 |
196 |
197 |
198 | ## Todo
199 |
200 | - [x] finish belief state decoder
201 | - [x] wrapper class that instantiates both teacher and student, handle student forward pass with reconstruction loss + behavioral loss
202 | - [x] handle noising of exteroception for student
203 | - [x] add basic PPO logic for teacher
204 | - [x] add basic student training loop with mock environment
205 | - [x] make sure all hyperparameters for teacher PPO training + teacher / student distillation are in accordance with the appendix
206 | - [ ] noise scheduler for student (curriculum factor that goes from 0 to 1 from epochs 1 to 100)
207 | - [ ] fix student training, it does not look correct
208 | - [ ] make sure tbptt is setup correctly
209 | - [ ] add reward crafting as in paper
210 | - [ ] play around with deepmind's mujoco
211 |
212 | ## Diagrams
213 |
214 |
215 |
216 | ## Citations
217 |
218 | ```bibtex
219 | @article{2022,
220 | title = {Learning robust perceptive locomotion for quadrupedal robots in the wild},
221 | url = {http://dx.doi.org/10.1126/scirobotics.abk2822},
222 | journal = {Science Robotics},
223 | publisher = {American Association for the Advancement of Science (AAAS)},
224 | author = {Miki, Takahiro and Lee, Joonho and Hwangbo, Jemin and Wellhausen, Lorenz and Koltun, Vladlen and Hutter, Marco},
225 | year = {2022},
226 | month = {Jan}
227 | }
228 | ```
229 |
--------------------------------------------------------------------------------
/anymal_belief_state_encoder_decoder_pytorch/ppo.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple, deque
2 |
3 | import torch
4 | from torch import nn, cat, stack
5 | from torch.nn import Module
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset, DataLoader, TensorDataset
8 | from torch.optim import Adam
9 |
10 | from assoc_scan import AssocScan
11 |
12 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
13 | from anymal_belief_state_encoder_decoder_pytorch.networks import unfreeze_all_layers_
14 |
15 | from einops import rearrange
16 |
17 | # helper functions
18 |
19 | def exists(val):
20 | return val is not None
21 |
22 | def default(v, d):
23 | return v if exists(v) else d
24 |
25 | # they use basic PPO for training the teacher with privileged information
26 | # then they used noisy student training, using the trained "oracle" teacher as guide
27 |
28 | # ppo data
29 |
30 | Memory = namedtuple('Memory', ['state', 'action', 'action_log_prob', 'reward', 'done', 'value'])
31 |
32 | def create_shuffled_dataloader(data, batch_size):
33 | ds = TensorDataset(*data)
34 | return DataLoader(ds, batch_size = batch_size, shuffle = True)
35 |
36 | # ppo helper functions
37 |
38 | def normalize(t, eps = 1e-5):
39 | return (t - t.mean()) / (t.std() + eps)
40 |
41 | # generalized advantage estimate
42 |
43 | def calc_generalized_advantage_estimate(
44 | rewards,
45 | values,
46 | masks,
47 | gamma = 0.99,
48 | lam = 0.95,
49 | use_accelerated = None
50 | ):
51 | device, is_cuda = rewards.device, rewards.is_cuda
52 | use_accelerated = default(use_accelerated, is_cuda)
53 |
54 | values, values_next = values[:-1], values[1:]
55 |
56 | delta = rewards + gamma * values_next * masks - values
57 | gates = gamma * lam * masks
58 |
59 | scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
60 |
61 | gae = scan(gates, delta)
62 | return gae + values
63 |
64 | def clipped_value_loss(values, rewards, old_values, clip):
65 | value_clipped = old_values + (values - old_values).clamp(-clip, clip)
66 | value_loss_1 = (value_clipped.flatten() - rewards) ** 2
67 | value_loss_2 = (values.flatten() - rewards) ** 2
68 | return torch.mean(torch.max(value_loss_1, value_loss_2))
69 |
70 | # mock environment
71 |
72 | class MockEnv(object):
73 | def __init__(
74 | self,
75 | proprio_dim,
76 | extero_dim,
77 | privileged_dim,
78 | num_legs = 4
79 | ):
80 | self.proprio_dim = proprio_dim
81 | self.extero_dim = extero_dim
82 | self.privileged_dim = privileged_dim
83 | self.num_legs = num_legs
84 |
85 | def rand_state(self):
86 | return (
87 | torch.randn((self.proprio_dim,)),
88 | torch.randn((self.num_legs, self.extero_dim,)),
89 | torch.randn((self.privileged_dim,))
90 | )
91 |
92 | def reset(self):
93 | return self.rand_state()
94 |
95 | def step(self, action):
96 | reward = torch.randn((1,))
97 | done = torch.tensor([False])
98 | return self.rand_state(), reward, done, None
99 |
100 | # main ppo class
101 |
102 | class PPO(Module):
103 | def __init__(
104 | self,
105 | *,
106 | env,
107 | anymal,
108 | epochs = 2,
109 | lr = 5e-4,
110 | betas = (0.9, 0.999),
111 | eps_clip = 0.2,
112 | beta_s = 0.005,
113 | value_clip = 0.4,
114 | max_timesteps = 10000,
115 | update_timesteps = 5000,
116 | lam = 0.95,
117 | gamma = 0.99,
118 | minibatch_size = 8300
119 | ):
120 | super().__init__()
121 | assert isinstance(anymal, Anymal)
122 | self.env = env
123 | self.anymal = anymal
124 |
125 | self.minibatch_size = minibatch_size
126 | self.optimizer = Adam(anymal.teacher.parameters(), lr = lr, betas = betas)
127 | self.epochs = epochs
128 |
129 | self.max_timesteps = max_timesteps
130 | self.update_timesteps = update_timesteps
131 |
132 | self.beta_s = beta_s
133 | self.eps_clip = eps_clip
134 | self.value_clip = value_clip
135 |
136 | self.lam = lam
137 | self.gamma = gamma
138 |
139 | # in paper, they said observations fed to teacher were normalized
140 | # by running mean
141 |
142 | self.running_proprio, self.running_extero = anymal.get_observation_running_stats()
143 |
144 | def learn_from_memories(
145 | self,
146 | memories,
147 | next_states
148 | ):
149 | device = next(self.parameters()).device
150 |
151 | # retrieve and prepare data from memory for training
152 |
153 | (
154 | states,
155 | actions,
156 | old_log_probs,
157 | rewards,
158 | dones,
159 | values
160 | ) = tuple(zip(*memories))
161 |
162 | states = tuple(zip(*states))
163 |
164 | # calculate generalized advantage estimate
165 |
166 | rewards = cat(rewards).to(device)
167 | values = cat(values).to(device).detach()
168 | masks = 1. - cat(dones).to(device).float()
169 |
170 | next_states = [t.to(device) for t in next_states]
171 | next_states = [rearrange(t, '... -> 1 ...') for t in next_states]
172 |
173 | with torch.no_grad():
174 | self.anymal.eval()
175 | _, next_value = self.anymal.forward_teacher(*next_states, return_value_head = True)
176 | next_value = next_value.detach()
177 |
178 | values_with_next = cat((values, next_value))
179 |
180 | returns = calc_generalized_advantage_estimate(rewards, values_with_next, masks, self.gamma, self.lam).detach()
181 |
182 | # convert values to torch tensors
183 |
184 | to_torch_tensor = lambda t: stack(t).to(device).detach()
185 |
186 | states = map(to_torch_tensor, states)
187 | actions = to_torch_tensor(actions)
188 | old_log_probs = to_torch_tensor(old_log_probs)
189 |
190 | # prepare dataloader for policy phase training
191 |
192 |         dl = create_shuffled_dataloader([*states, actions, old_log_probs, returns, values], self.minibatch_size)
193 |
194 | # policy phase training, similar to original PPO
195 |
196 | for _ in range(self.epochs):
197 | for proprio, extero, privileged, actions, old_log_probs, rewards, old_values in dl:
198 |
199 | dist, values = self.anymal.forward_teacher(
200 | proprio, extero, privileged,
201 | return_value_head = True,
202 | return_action_categorical_dist = True
203 | )
204 |
205 | action_log_probs = dist.log_prob(actions)
206 |
207 | entropy = dist.entropy()
208 | ratios = (action_log_probs - old_log_probs).exp()
209 | advantages = normalize(rewards - old_values.detach())
210 | surr1 = ratios * advantages
211 | surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
212 |
213 | policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropy
214 |
215 | value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
216 |
217 | (policy_loss.mean() + value_loss.mean()).backward()
218 | self.optimizer.step()
219 | self.optimizer.zero_grad()
220 |
221 |     # does one episode's worth of learning
222 |
223 | def forward(self):
224 | device = next(self.parameters()).device
225 | unfreeze_all_layers_(self.anymal)
226 |
227 | time = 0
228 | states = self.env.reset() # states assumed to be (proprioception, exteroception, privileged information)
229 | memories = deque([])
230 |
231 | self.running_proprio.clear()
232 | self.running_extero.clear()
233 |
234 | for timestep in range(self.max_timesteps):
235 | time += 1
236 |
237 | states = list(map(lambda t: t.to(device), states))
238 | proprio, extero, privileged = states
239 |
240 | # update running means for observations, for teacher
241 |
242 | self.running_proprio.push(proprio)
243 | self.running_extero.push(extero)
244 |
245 | # normalize observation states for teacher (proprio and extero)
246 |
247 | states = (
248 | self.running_proprio.norm(proprio),
249 | self.running_extero.norm(extero),
250 | privileged
251 | )
252 |
253 | anymal_states = list(map(lambda t: rearrange(t, '... -> 1 ...'), states))
254 |
255 | dist, values = self.anymal.forward_teacher(
256 | *anymal_states,
257 | return_value_head = True,
258 | return_action_categorical_dist = True
259 | )
260 |
261 | action = dist.sample()
262 | action_log_prob = dist.log_prob(action)
263 |
264 | next_states, reward, done, _ = self.env.step(action)
265 |
266 | memory = Memory(states, action, action_log_prob, reward, done, values)
267 | memories.append(memory)
268 |
269 | states = next_states
270 |
271 | if time % self.update_timesteps == 0:
272 | self.learn_from_memories(memories, next_states)
273 | memories.clear()
274 |
275 | if done:
276 | break
277 |
278 | print('trained for 1 episode')
279 |
--------------------------------------------------------------------------------
/anymal_belief_state_encoder_decoder_pytorch/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn import GRUCell
5 | from torch.distributions import Categorical
6 | from torch.optim import Adam
7 |
8 | from einops import rearrange
9 | from einops.layers.torch import Rearrange
10 |
11 | from anymal_belief_state_encoder_decoder_pytorch.running import RunningStats
12 |
13 | # helper functions
14 |
15 | def exists(val):
16 | return val is not None
17 |
18 | def check_shape(tensor, pattern, **kwargs):
19 | return rearrange(tensor, f'{pattern} -> {pattern}', **kwargs)
20 |
21 | # freezing of neural networks (teacher needs to be frozen)
22 |
23 | def set_module_requires_grad_(module, requires_grad):
24 | for param in module.parameters():
25 | param.requires_grad = requires_grad
26 |
27 | def freeze_all_layers_(module):
28 | set_module_requires_grad_(module, False)
29 |
30 | def unfreeze_all_layers_(module):
31 | set_module_requires_grad_(module, True)
32 |
33 | # in the paper
34 | # the network attention gates the exteroception, and then sums it to the belief state
35 | # todo: make sure the padding is on the right side
36 |
37 | def sum_with_zeropad(x, y):
38 | x_dim, y_dim = x.shape[-1], y.shape[-1]
39 |
40 | if x_dim == y_dim:
41 | return x + y
42 |
43 | if x_dim < y_dim:
44 | x = F.pad(x, (y_dim - x_dim, 0))
45 |
46 | if y_dim < x_dim:
47 | y = F.pad(y, (x_dim - y_dim, 0))
48 |
49 | return x + y
50 |
51 | # add basic MLP
52 |
53 | class MLP(nn.Module):
54 | def __init__(
55 | self,
56 | dims,
57 | activation = nn.LeakyReLU,
58 | final_activation = False
59 | ):
60 | super().__init__()
61 | assert isinstance(dims, (list, tuple))
62 | assert len(dims) > 2, 'must have at least 3 dimensions (input, *hiddens, output)'
63 |
64 | dim_pairs = list(zip(dims[:-1], dims[1:]))
65 | *dim_pairs, dim_out_pair = dim_pairs
66 |
67 | layers = []
68 | for dim_in, dim_out in dim_pairs:
69 | layers.extend([
70 | nn.Linear(dim_in, dim_out),
71 | activation()
72 | ])
73 |
74 | layers.append(nn.Linear(*dim_out_pair))
75 |
76 | if final_activation:
77 | layers.append(activation())
78 |
79 | self.net = nn.Sequential(*layers)
80 |
81 | def forward(self, x):
82 | if isinstance(x, (tuple, list)):
83 | x = torch.cat(x, dim = -1)
84 |
85 | return self.net(x)
86 |
87 | class Student(nn.Module):
88 | def __init__(
89 | self,
90 | num_actions,
91 | proprio_dim = 133,
92 |         extero_dim = 52, # in the paper, height samples were marked as 208, but wasn't sure if that was per leg, or (4 legs x 52) = 208
93 | latent_extero_dim = 24,
94 | extero_encoder_hidden = (80, 60),
95 | belief_state_encoder_hiddens = (64, 64),
96 | extero_gate_encoder_hiddens = (64, 64),
97 |         belief_state_dim = 120, # should be equal to the teacher's (latent_extero_dim * num_legs) + latent_privileged_dim (part of the GRU's responsibility is to maintain a hidden state that forms an opinion on the privileged information)
98 | gru_num_layers = 2,
99 | gru_hidden_size = 50,
100 | mlp_hidden = (256, 160, 128),
101 | num_legs = 4,
102 | privileged_dim = 50,
103 | privileged_decoder_hiddens = (64, 64),
104 | extero_decoder_hiddens = (64, 64),
105 | ):
106 | super().__init__()
107 | assert belief_state_dim > (num_legs * latent_extero_dim)
108 | self.num_legs = num_legs
109 | self.proprio_dim = proprio_dim
110 | self.extero_dim = extero_dim
111 |
112 | # encoding of exteroception
113 |
114 | self.extero_encoder = MLP((extero_dim, *extero_encoder_hidden, latent_extero_dim))
115 |
116 | # GRU related parameters
117 |
118 | gru_input_dim = (latent_extero_dim * num_legs) + proprio_dim
119 | gru_input_dims = (gru_input_dim, *((gru_hidden_size,) * (gru_num_layers - 1)))
120 | self.gru_cells = nn.ModuleList([GRUCell(input_dim, gru_hidden_size) for input_dim in gru_input_dims])
121 | self.gru_hidden_size = gru_hidden_size
122 |
123 | # belief state encoding
124 |
125 | self.belief_state_encoder = MLP((gru_hidden_size, *belief_state_encoder_hiddens, belief_state_dim))
126 |
127 | # attention gating of exteroception
128 |
129 | self.to_latent_extero_attn_gate = MLP((gru_hidden_size, *extero_gate_encoder_hiddens, latent_extero_dim * num_legs))
130 |
131 | # belief state decoder
132 |
133 | self.privileged_decoder = MLP((gru_hidden_size, *privileged_decoder_hiddens, privileged_dim))
134 | self.extero_decoder = MLP((gru_hidden_size, *extero_decoder_hiddens, extero_dim * num_legs))
135 |
136 | self.to_extero_attn_gate = MLP((gru_hidden_size, *extero_gate_encoder_hiddens, extero_dim * num_legs))
137 |
138 | # final MLP to action logits
139 |
140 | self.to_logits = MLP((
141 | belief_state_dim + proprio_dim,
142 | *mlp_hidden
143 | ))
144 |
145 | self.to_action_head = nn.Sequential(
146 | nn.LeakyReLU(),
147 | nn.Linear(mlp_hidden[-1], num_actions)
148 | )
149 |
150 | def get_gru_hiddens(self):
151 | device = next(self.parameters()).device
152 |         return torch.zeros((len(self.gru_cells), self.gru_hidden_size), device = device)
153 |
154 | def forward(
155 | self,
156 | proprio,
157 | extero,
158 | hiddens = None,
159 |         return_estimated_info = False, # for returning estimated privileged info + exteroceptive info, for reconstruction loss
160 | return_action_categorical_dist = False
161 | ):
162 | check_shape(proprio, 'b d', d = self.proprio_dim)
163 | check_shape(extero, 'b n d', n = self.num_legs, d = self.extero_dim)
164 |
165 | latent_extero = self.extero_encoder(extero)
166 | latent_extero = rearrange(latent_extero, 'b ... -> b (...)')
167 |
168 | # RNN
169 |
170 | if not exists(hiddens):
171 | prev_hiddens = (None,) * len(self.gru_cells)
172 | else:
173 | prev_hiddens = hiddens.unbind(dim = -2)
174 |
175 | gru_input = torch.cat((proprio, latent_extero), dim = -1)
176 |
177 | next_hiddens = []
178 | for gru_cell, prev_hidden in zip(self.gru_cells, prev_hiddens):
179 | gru_input = gru_cell(gru_input, prev_hidden)
180 | next_hiddens.append(gru_input)
181 |
182 | gru_output = gru_input
183 |
184 | next_hiddens = torch.stack(next_hiddens, dim = -2)
185 |
186 | # attention gating of exteroception
187 |
188 | latent_extero_attn_gate = self.to_latent_extero_attn_gate(gru_output)
189 | gated_latent_extero = latent_extero * latent_extero_attn_gate.sigmoid()
190 |
191 | # belief state and add gated exteroception
192 |
193 | belief_state = self.belief_state_encoder(gru_output)
194 | belief_state = sum_with_zeropad(belief_state, gated_latent_extero)
195 |
196 | # to action logits
197 |
198 | belief_state_with_proprio = torch.cat((
199 | proprio,
200 | belief_state,
201 | ), dim = 1)
202 |
203 | logits = self.to_logits(belief_state_with_proprio)
204 |
205 | pi_logits = self.to_action_head(logits)
206 |
207 | return_action = Categorical(pi_logits.softmax(dim = -1)) if return_action_categorical_dist else pi_logits
208 |
209 | if not return_estimated_info:
210 | return return_action, next_hiddens
211 |
212 | # belief state decoding
213 | # for reconstructing privileged and exteroception information from hidden belief states
214 |
215 | recon_privileged = self.privileged_decoder(gru_output)
216 | recon_extero = self.extero_decoder(gru_output)
217 | extero_attn_gate = self.to_extero_attn_gate(gru_output)
218 |
219 | gated_extero = rearrange(extero, 'b ... -> b (...)') * extero_attn_gate.sigmoid()
220 | recon_extero = recon_extero + gated_extero
221 | recon_extero = rearrange(recon_extero, 'b (n d) -> b n d', n = self.num_legs)
222 |
223 | # whether to return raw policy logits or action probs wrapped with Categorical
224 |
225 | return return_action, next_hiddens, (recon_privileged, recon_extero)
226 |
227 | class Teacher(nn.Module):
228 | def __init__(
229 | self,
230 | num_actions,
231 | proprio_dim = 133,
232 |         extero_dim = 52, # in the paper, height samples were marked as 208, but wasn't sure if that was per leg, or (4 legs x 52) = 208
233 | latent_extero_dim = 24,
234 | extero_encoder_hidden = (80, 60),
235 | privileged_dim = 50,
236 | latent_privileged_dim = 24,
237 | privileged_encoder_hidden = (64, 32),
238 | mlp_hidden = (256, 160, 128),
239 | num_legs = 4
240 | ):
241 | super().__init__()
242 | self.num_legs = num_legs
243 | self.proprio_dim = proprio_dim
244 | self.extero_dim = extero_dim
245 | self.privileged_dim = privileged_dim
246 |
247 | self.extero_encoder = MLP((extero_dim, *extero_encoder_hidden, latent_extero_dim))
248 | self.privileged_encoder = MLP((privileged_dim, *privileged_encoder_hidden, latent_privileged_dim))
249 |
250 | self.to_logits = MLP((
251 | latent_extero_dim * num_legs + latent_privileged_dim + proprio_dim,
252 | *mlp_hidden
253 | ))
254 |
255 | self.to_action_head = nn.Sequential(
256 | nn.LeakyReLU(),
257 | nn.Linear(mlp_hidden[-1], num_actions)
258 | )
259 |
260 | self.to_value_head = nn.Sequential(
261 | nn.LeakyReLU(),
262 | nn.Linear(mlp_hidden[-1], 1),
263 | Rearrange('... 1 -> ...')
264 | )
265 |
266 | def forward(
267 | self,
268 | proprio,
269 | extero,
270 | privileged,
271 | return_value_head = False,
272 | return_action_categorical_dist = False
273 | ):
274 | check_shape(proprio, 'b d', d = self.proprio_dim)
275 | check_shape(extero, 'b n d', n = self.num_legs, d = self.extero_dim)
276 | check_shape(privileged, 'b d', d = self.privileged_dim)
277 |
278 | latent_extero = self.extero_encoder(extero)
279 | latent_extero = rearrange(latent_extero, 'b ... -> b (...)')
280 |
281 | latent_privileged = self.privileged_encoder(privileged)
282 |
283 | latent = torch.cat((
284 | proprio,
285 | latent_extero,
286 | latent_privileged,
287 | ), dim = -1)
288 |
289 | logits = self.to_logits(latent)
290 |
291 | pi_logits = self.to_action_head(logits)
292 |
293 | if not return_value_head:
294 | return pi_logits
295 |
296 | value_logits = self.to_value_head(logits)
297 |
298 | return_action = Categorical(pi_logits.softmax(dim = -1)) if return_action_categorical_dist else pi_logits
299 | return return_action, value_logits
300 |
301 | # manages both teacher and student under one module
302 |
303 | class Anymal(nn.Module):
304 | def __init__(
305 | self,
306 | num_actions,
307 | proprio_dim = 133,
308 | extero_dim = 52,
309 | privileged_dim = 50,
310 | num_legs = 4,
311 | latent_extero_dim = 24,
312 | latent_privileged_dim = 24,
313 | teacher_extero_encoder_hidden = (80, 60),
314 | teacher_privileged_encoder_hidden = (64, 32),
315 | student_extero_gate_encoder_hiddens = (64, 64),
316 | student_belief_state_encoder_hiddens = (64, 64),
317 | student_belief_state_dim = 120,
318 | student_gru_num_layers = 2,
319 | student_gru_hidden_size = 50,
320 | student_privileged_decoder_hiddens = (64, 64),
321 | student_extero_decoder_hiddens = (64, 64),
322 | student_extero_encoder_hidden = (80, 60),
323 | mlp_hidden = (256, 160, 128),
324 | recon_loss_weight = 0.5
325 | ):
326 | super().__init__()
327 | self.proprio_dim = proprio_dim
328 | self.num_legs = num_legs
329 | self.extero_dim = extero_dim
330 |
331 | self.student = Student(
332 | num_actions = num_actions,
333 | proprio_dim = proprio_dim,
334 | extero_dim = extero_dim,
335 | latent_extero_dim = latent_extero_dim,
336 | extero_encoder_hidden = student_extero_encoder_hidden,
337 | belief_state_encoder_hiddens = student_belief_state_encoder_hiddens,
338 | extero_gate_encoder_hiddens = student_extero_gate_encoder_hiddens,
339 | belief_state_dim = student_belief_state_dim,
340 | gru_num_layers = student_gru_num_layers,
341 | gru_hidden_size = student_gru_hidden_size,
342 | mlp_hidden = mlp_hidden,
343 | num_legs = num_legs,
344 | privileged_dim = privileged_dim,
345 | privileged_decoder_hiddens = student_privileged_decoder_hiddens,
346 | extero_decoder_hiddens = student_extero_decoder_hiddens,
347 | )
348 |
349 | self.teacher = Teacher(
350 | num_actions = num_actions,
351 | proprio_dim = proprio_dim,
352 | extero_dim = extero_dim,
353 | latent_extero_dim = latent_extero_dim,
354 | extero_encoder_hidden = teacher_extero_encoder_hidden,
355 | privileged_dim = privileged_dim,
356 | latent_privileged_dim = latent_privileged_dim,
357 | privileged_encoder_hidden = teacher_privileged_encoder_hidden,
358 | mlp_hidden = mlp_hidden,
359 | num_legs = num_legs
360 | )
361 |
362 | self.recon_loss_weight = recon_loss_weight
363 |
364 | def get_observation_running_stats(self):
365 | return RunningStats(self.proprio_dim), RunningStats((self.num_legs, self.extero_dim))
366 |
367 | def init_student_with_teacher(self):
368 | self.student.extero_encoder.load_state_dict(self.teacher.extero_encoder.state_dict())
369 | self.student.to_logits.load_state_dict(self.teacher.to_logits.state_dict())
370 | self.student.to_action_head.load_state_dict(self.teacher.to_action_head.state_dict())
371 |
372 | def forward_teacher(self, *args, return_value_head = False, **kwargs):
373 | return self.teacher(*args, return_value_head = return_value_head, **kwargs)
374 |
375 | def forward_student(self, *args, **kwargs):
376 | return self.student(*args, **kwargs)
377 |
378 | # main forward for training the student with teacher as guide
379 |
380 | def forward(
381 | self,
382 | proprio,
383 | extero,
384 | privileged,
385 | teacher_states = None,
386 | hiddens = None,
387 | noise_strength = 0.1
388 | ):
389 | self.teacher.eval()
390 | freeze_all_layers_(self.teacher)
391 |
392 | with torch.no_grad():
393 | teacher_proprio, teacher_extero = teacher_states if exists(teacher_states) else (proprio, extero)
394 | teacher_action_logits = self.forward_teacher(teacher_proprio, teacher_extero, privileged)
395 |
396 | noised_extero = extero + torch.rand_like(extero) * noise_strength
397 |
398 | student_action_logits, hiddens, recons = self.student(proprio, noised_extero, hiddens = hiddens, return_estimated_info = True)
399 |
400 | # calculate reconstruction loss of privileged and denoised exteroception
401 |
402 | (recon_privileged, recon_extero) = recons
403 | recon_loss = F.mse_loss(recon_privileged, privileged) + F.mse_loss(recon_extero, extero)
404 |
405 | # calculate behavior loss, which is also squared distance?
406 |
407 | behavior_loss = F.mse_loss(teacher_action_logits, student_action_logits) # why not kl div on action probs?
408 |
409 | loss = behavior_loss + recon_loss * self.recon_loss_weight
410 | return loss, hiddens
411 |
--------------------------------------------------------------------------------