├── real-anymal.png ├── anymal-beliefs.png ├── anymal-teacher-student.png ├── anymal_belief_state_encoder_decoder_pytorch ├── __init__.py ├── running.py ├── trainer.py ├── ppo.py └── networks.py ├── setup.py ├── LICENSE ├── .github └── workflows │ └── python-publish.yml ├── .gitignore └── README.md /real-anymal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch/HEAD/real-anymal.png -------------------------------------------------------------------------------- /anymal-beliefs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch/HEAD/anymal-beliefs.png -------------------------------------------------------------------------------- /anymal-teacher-student.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch/HEAD/anymal-teacher-student.png -------------------------------------------------------------------------------- /anymal_belief_state_encoder_decoder_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from anymal_belief_state_encoder_decoder_pytorch.networks import Student, Teacher, MLP, Anymal 2 | from anymal_belief_state_encoder_decoder_pytorch.ppo import PPO, MockEnv 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'anymal-belief-state-encoder-decoder-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.0.21', 7 | license='MIT', 8 | description = 'Anymal Belief-state Encoder Decoder - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | url = 'https://github.com/lucidrains/anymal-belief-state-encoder-decoder-pytorch', 12 | keywords = [ 13 | 'artificial intelligence', 14 | 'deep learning', 15 | 'attention gating', 16 | 'belief state', 17 | 'robotics' 18 | ], 19 | install_requires=[ 20 | 'assoc-scan>=0.0.2', 21 | 'einops>=0.8', 22 | 'evolutionary-policy-optimization>=0.0.61', 23 | 'torch>=2.2', 24 | ], 25 | classifiers=[ 26 | 'Development Status :: 4 - Beta', 27 | 'Intended Audience :: Developers', 28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Programming Language :: Python :: 3.6', 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /anymal_belief_state_encoder_decoder_pytorch/running.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class RunningStats(nn.Module): 5 | def __init__(self, shape, eps = 1e-5): 6 | super().__init__() 7 | shape = shape if isinstance(shape, tuple) else (shape,) 8 | 9 | self.shape = shape 10 | self.eps = eps 11 | self.n = 0 12 | 13 | self.register_buffer('old_mean', torch.zeros(shape), persistent = False) 14 | self.register_buffer('new_mean', torch.zeros(shape), persistent = False) 15 | self.register_buffer('old_std', torch.zeros(shape), persistent = False) 16 | self.register_buffer('new_std', torch.zeros(shape), persistent = False) 17 | 18 | def clear(self): 19 | self.n = 0 20 | 21 | def push(self, x): 22 | self.n += 1 23 | 24 | if self.n == 1: 25 | self.old_mean.copy_(x.data) 26 | self.new_mean.copy_(x.data) 27 | self.old_std.zero_() 28 | self.new_std.zero_() 29 | return 30 | 31 | self.new_mean.copy_(self.old_mean + (x - self.old_mean) / self.n) 32 | self.new_std.copy_(self.old_std + (x - self.old_mean) * (x - self.new_mean)) 33 | 34 | self.old_mean.copy_(self.new_mean) 35 | self.old_std.copy_(self.new_std) 36 | 37 | def mean(self): 38 | return self.new_mean if self.n else torch.zeros_like(self.new_mean) 39 | 40 | def variance(self): 41 | return (self.new_std / (self.n - 1)) if self.n > 1 else torch.zeros_like(self.new_std) 42 | 43 | def rstd(self): 44 | return torch.rsqrt(self.variance() + self.eps) 45 | 46 | def norm(self, x): 47 | return (x - self.mean()) * self.rstd() 48 | 
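# Usage sketch (illustrative addition, not part of the original module) - RunningStats
# keeps Welford-style running mean / variance buffers so an observation stream can be
# normalized online, as done for the teacher / student observations in the trainers.

if __name__ == '__main__':
    stats = RunningStats(133)          # e.g. a proprioception vector

    for _ in range(16):
        stats.push(torch.randn(133))   # update running mean / variance one sample at a time

    x = torch.randn(133)
    normed = stats.norm(x)             # (x - running mean) * rsqrt(running variance + eps)
    print(normed.shape)                # torch.Size([133])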
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /anymal_belief_state_encoder_decoder_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.data import Dataset, DataLoader 4 | from torch.optim import Adam 5 | 6 | from collections import deque 7 | from einops import rearrange 8 | 9 | from anymal_belief_state_encoder_decoder_pytorch import Anymal 10 | 11 | class ExperienceDataset(Dataset): 12 | def __init__(self, data): 13 | super().__init__() 14 | self.data = data 15 | 16 | def __len__(self): 17 | return len(self.data[0]) 18 | 19 | def __getitem__(self, ind): 20 | return tuple(map(lambda t: t[ind], self.data)) 21 | 22 | def create_dataloader(data, batch_size): 23 | ds = ExperienceDataset(data) 24 | return DataLoader(ds, batch_size = batch_size, drop_last = True) 25 | 26 | class StudentTrainer(nn.Module): 27 | def __init__( 28 | self, 29 | *, 30 | anymal, 31 | env, 32 | epochs = 2, 33 | lr = 5e-4, 34 | max_timesteps = 10000, 35 | update_timesteps = 5000, 36 | minibatch_size = 16, 37 | truncate_tpbtt = 10 38 | ): 39 | super().__init__() 40 | self.env = env 41 | self.anymal = anymal 42 | self.optimizer = Adam(anymal.student.parameters(), lr = lr) 43 | self.epochs = epochs 44 | 45 | self.max_timesteps = max_timesteps 46 | self.update_timesteps = update_timesteps 47 | self.minibatch_size = minibatch_size 48 | self.truncate_tpbtt = truncate_tpbtt 49 | 50 | self.running_proprio, self.running_extero = anymal.get_observation_running_stats() 51 | 52 | def learn_from_memories( 53 | self, 54 | memories, 55 | next_states, 56 | noise_strength = 0. 
57 |     ):
58 |         device = next(self.parameters()).device
59 | 
60 |         # retrieve and prepare data from memory for training
61 | 
62 |         states = []
63 |         teacher_states = []
64 |         hiddens = []
65 |         dones = []
66 | 
67 |         for (state, teacher_state, hidden, done) in memories:
68 |             states.append(state)
69 |             teacher_states.append(teacher_state)
70 |             hiddens.append(hidden)
71 |             dones.append(torch.Tensor([done]))
72 | 
73 |         states = tuple(zip(*states))
74 |         teacher_states = tuple(zip(*teacher_states))
75 | 
76 |         # convert values to torch tensors
77 | 
78 |         to_torch_tensor = lambda t: torch.stack(t).to(device).detach()
79 | 
80 |         states = map(to_torch_tensor, states)
81 |         teacher_states = map(to_torch_tensor, teacher_states)
82 |         hiddens = to_torch_tensor(hiddens)
83 |         dones = to_torch_tensor(dones)
84 | 
85 |         # prepare dataloader for policy phase training
86 | 
87 |         dl = create_dataloader([*states, *teacher_states, hiddens, dones], self.minibatch_size)
88 | 
89 |         current_hiddens = self.anymal.student.get_gru_hiddens()
90 |         current_hiddens = rearrange(current_hiddens, 'l d -> 1 l d')
91 | 
92 |         for _ in range(self.epochs):
93 |             for ind, (proprio, extero, privileged, teacher_proprio, teacher_extero, episode_hiddens, done) in enumerate(dl):
94 | 
95 |                 straight_through_hiddens = current_hiddens - current_hiddens.detach() + episode_hiddens
96 | 
97 |                 loss, current_hiddens = self.anymal(
98 |                     proprio,
99 |                     extero,
100 |                     privileged,
101 |                     teacher_states = (teacher_proprio, teacher_extero),
102 |                     hiddens = straight_through_hiddens,
103 |                     noise_strength = noise_strength
104 |                 )
105 | 
106 |                 loss.backward(retain_graph = True)
107 | 
108 |                 tbptt_limit = not ((ind + 1) % self.truncate_tpbtt)
109 |                 if tbptt_limit: # how far back in time should the gradients go for recurrence
110 |                     self.optimizer.step()
111 |                     self.optimizer.zero_grad()
112 |                     current_hiddens = current_hiddens.detach()
113 | 
114 |                 # detach hiddens depending on whether it is a new episode or not
115 |                 # todo: restructure dataloader to load one episode per batch row
116 | 
117 |                 maybe_detached_hiddens = []
118 |                 for current_hidden, done in zip(current_hiddens.unbind(dim = 0), dones.unbind(dim = 0)):
119 |                     maybe_detached_hiddens.append(current_hidden.detach() if done else current_hidden)
120 | 
121 |                 current_hiddens = torch.stack(maybe_detached_hiddens)
122 | 
123 |     def forward(
124 |         self,
125 |         noise_strength = 0.
126 |     ):
127 |         device = next(self.parameters()).device
128 | 
129 |         time = 0
130 |         done = False
131 |         states = self.env.reset()
132 |         memories = deque([])
133 | 
134 |         hidden = self.anymal.student.get_gru_hiddens()
135 |         hidden = rearrange(hidden, 'l d -> 1 l d')
136 | 
137 |         self.running_proprio.clear()
138 |         self.running_extero.clear()
139 | 
140 |         for timestep in range(self.max_timesteps):
141 |             time += 1
142 | 
143 |             states = list(map(lambda t: t.to(device), states))
144 |             anymal_states = list(map(lambda t: rearrange(t, '... -> 1 ...'), states))
145 | 
146 |             # teacher needs to have normalized observations
147 | 
148 |             (proprio, extero, privileged) = states
149 | 
150 |             self.running_proprio.push(proprio)
151 |             self.running_extero.push(extero)
152 | 
153 |             teacher_states = (
154 |                 self.running_proprio.norm(proprio),
155 |                 self.running_extero.norm(extero)
156 |             )
157 | 
158 |             teacher_anymal_states = list(map(lambda t: rearrange(t, '... -> 1 ...'), teacher_states))
159 | 
160 |             # add states to memories
161 | 
162 |             memories.append((
163 |                 states,
164 |                 teacher_states,
165 |                 rearrange(hidden, '1 ... -> ...'),
166 |                 done
167 |             ))
168 | 
169 |             dist, hidden = self.anymal.forward_student(
170 |                 *anymal_states[:-1],
171 |                 hiddens = hidden,
172 |                 return_action_categorical_dist = True
173 |             )
174 | 
175 |             action = dist.sample()
176 |             action_log_prob = dist.log_prob(action)
177 |             action = action.item()
178 | 
179 |             next_states, _, done, _ = self.env.step(action)
180 | 
181 |             states = next_states
182 | 
183 |             if time % self.update_timesteps == 0:
184 |                 self.learn_from_memories(memories, next_states, noise_strength = noise_strength)
185 |                 memories.clear()
186 | 
187 |             if done:
188 |                 break
189 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Belief State Encoder / Decoder (Anymal) - Pytorch
4 | 
5 | Implementation of the Belief State Encoder / Decoder in the new breakthrough robotics paper from ETH Zürich.
6 | 
7 | This paper is important as it seems their learned approach produced a policy that rivals Boston Dynamics' handcrafted algorithms (the quadrupedal Spot).
8 | 
9 | The results speak for themselves in their video demonstration.
10 | 
11 | ## Install
12 | 
13 | ```bash
14 | $ pip install anymal-belief-state-encoder-decoder-pytorch
15 | ```
16 | 
17 | ## Usage
18 | 
19 | Teacher
20 | 
21 | ```python
22 | import torch
23 | from anymal_belief_state_encoder_decoder_pytorch import Teacher
24 | 
25 | teacher = Teacher(
26 |     num_actions = 10,
27 |     num_legs = 4,
28 |     extero_dim = 52,
29 |     proprio_dim = 133,
30 |     privileged_dim = 50
31 | )
32 | 
33 | proprio = torch.randn(1, 133)
34 | extero = torch.randn(1, 4, 52)
35 | privileged = torch.randn(1, 50)
36 | 
37 | action_logits, values = teacher(proprio, extero, privileged, return_value_head = True) # (1, 10), (1,)
38 | ```
39 | 
40 | Student
41 | 
42 | ```python
43 | import torch
44 | from anymal_belief_state_encoder_decoder_pytorch import Student
45 | 
46 | student = Student(
47 |     num_actions = 10,
48 |     num_legs = 4,
49 |     extero_dim = 52,
50 |     proprio_dim = 133,
51 |     gru_num_layers = 2,
52 |     gru_hidden_size = 50
53 | )
54 | 
55 | proprio = torch.randn(1, 133)
56 | extero = torch.randn(1, 4, 52)
57 | 
58 | action_logits, hiddens = student(proprio, extero) # (1, 10), (1, 2, 50)
59 | action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (1, 2, 50)
60 | action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (1, 2, 50)
61 | 
62 | # hiddens are in the shape (batch size, num gru layers, gru hidden dimension)
63 | # train with truncated bptt (see the sketch below)
64 | ```
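Since the student is recurrent, it is meant to be trained with truncated backpropagation through time: run it step by step and periodically cut the graph by detaching the hiddens. Below is a minimal sketch of that pattern (the `truncate_every` window and the objective are placeholders, purely to show where the gradient window ends - the real distillation loss comes from the full `Anymal` wrapper further down).

```python
import torch
from anymal_belief_state_encoder_decoder_pytorch import Student

student = Student(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    gru_num_layers = 2,
    gru_hidden_size = 50
)

optim = torch.optim.Adam(student.parameters(), lr = 3e-4)

truncate_every = 5  # hypothetical truncation window, in timesteps
hiddens = None
losses = []

for t in range(20):
    proprio = torch.randn(1, 133)
    extero = torch.randn(1, 4, 52)

    action_logits, hiddens = student(proprio, extero, hiddens)

    # placeholder objective - in practice this would be the distillation loss
    losses.append(action_logits.pow(2).mean())

    if (t + 1) % truncate_every == 0:
        torch.stack(losses).mean().backward()
        optim.step()
        optim.zero_grad()

        hiddens = hiddens.detach()  # cut the graph so gradients only span the window
        losses.clear()
```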
65 | 
66 | Full Anymal (which contains both Teacher and Student)
67 | 
68 | ```python
69 | import torch
70 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
71 | 
72 | anymal = Anymal(
73 |     num_actions = 10,
74 |     num_legs = 4,
75 |     extero_dim = 52,
76 |     proprio_dim = 133,
77 |     privileged_dim = 50,
78 |     recon_loss_weight = 0.5
79 | )
80 | 
81 | # mock data
82 | 
83 | proprio = torch.randn(1, 133)
84 | extero = torch.randn(1, 4, 52)
85 | privileged = torch.randn(1, 50)
86 | 
87 | # first train teacher
88 | 
89 | teacher_action_logits = anymal.forward_teacher(proprio, extero, privileged)
90 | 
91 | # teacher is trained with privileged information in simulation with domain randomization
92 | 
93 | # after teacher has satisfactory performance, init the student with the teacher weights, excluding the privileged information encoder from the teacher (which the student does not have)
94 | 
95 | anymal.init_student_with_teacher()
96 | 
97 | # then train the student on the proprioception and noised exteroception, forcing it to reconstruct the privileged information that the teacher had access to (as well as learning to denoise the exteroception) - there is also a behavior loss between the policy logits of the teacher and those of the student
98 | 
99 | loss, hiddens = anymal(proprio, extero, privileged)
100 | loss.backward()
101 | 
102 | # finally, you can deploy the student to the real world, zero-shot
103 | 
104 | anymal.eval()
105 | dist, hiddens = anymal.forward_student(proprio, extero, return_action_categorical_dist = True)
106 | action = dist.sample()
107 | ```
108 | 
109 | PPO training of the Teacher (using a mock environment - this needs to be substituted with an environment wrapper around the simulator)
110 | 
111 | ```python
112 | import torch
113 | from anymal_belief_state_encoder_decoder_pytorch import Anymal, PPO
114 | from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv
115 | 
116 | anymal = Anymal(
117 |     num_actions = 10,
118 |     num_legs = 4,
119 |     extero_dim = 52,
120 |     proprio_dim = 133,
121 |     privileged_dim = 50,
122 |     recon_loss_weight = 0.5
123 | )
124 | 
125 | mock_env = MockEnv(
126 |     proprio_dim = 133,
127 |     extero_dim = 52,
128 |     privileged_dim = 50
129 | )
130 | 
131 | ppo = PPO(
132 |     env = mock_env,
133 |     anymal = anymal,
134 |     epochs = 10,
135 |     lr = 3e-4,
136 |     eps_clip = 0.2,
137 |     beta_s = 0.01,
138 |     value_clip = 0.4,
139 |     max_timesteps = 10000,
140 |     update_timesteps = 5000,
141 | )
142 | 
143 | # train for 10 episodes
144 | 
145 | for _ in range(10):
146 |     ppo()
147 | 
148 | # save the weights of the teacher for student training
149 | 
150 | torch.save(anymal.state_dict(), './anymal-with-trained-teacher.pt')
151 | ```
152 | 
153 | To train the student
154 | 
155 | ```python
156 | import torch
157 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
158 | from anymal_belief_state_encoder_decoder_pytorch.trainer import StudentTrainer
159 | from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv
160 | 
161 | anymal = Anymal(
162 |     num_actions = 10,
163 |     num_legs = 4,
164 |     extero_dim = 52,
165 |     proprio_dim = 133,
166 |     privileged_dim = 50,
167 |     recon_loss_weight = 0.5
168 | )
169 | 
170 | # first init student with teacher weights, at the very beginning
171 | # if not resuming training (see the sketch below)
172 | 
173 | mock_env = MockEnv(
174 |     proprio_dim = 133,
175 |     extero_dim = 52,
176 |     privileged_dim = 50
177 | )
178 | 
179 | trainer = StudentTrainer(
180 |     anymal = anymal,
181 |     env = mock_env
182 | )
183 | 
184 | # for 100 episodes
185 | 
186 | for _ in range(100):
187 |     trainer()
188 | 
189 | ```
190 | 
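Putting the two stages together: before the very first round of student training, you would typically load the `Anymal` weights saved after teacher PPO and copy the shared teacher weights into the student (skip the copy if you are resuming an earlier student run). A minimal sketch:

```python
import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

# load the teacher trained with PPO above

anymal.load_state_dict(torch.load('./anymal-with-trained-teacher.pt'))

# initialize the student from the teacher before distillation begins

anymal.init_student_with_teacher()
```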
191 | ... You've beaten Boston Dynamics and its team of highly paid control engineers!
192 | 
193 | But you probably haven't beaten a real quadrupedal "anymal" just yet :)
194 | 
195 | 
196 | 
197 | 
198 | ## Todo
199 | 
200 | - [x] finish belief state decoder
201 | - [x] wrapper class that instantiates both teacher and student, handle student forward pass with reconstruction loss + behavioral loss
202 | - [x] handle noising of exteroception for student
203 | - [x] add basic PPO logic for teacher
204 | - [x] add basic student training loop with mock environment
205 | - [x] make sure all hyperparameters for teacher PPO training + teacher / student distillation are in accordance with the appendix
206 | - [ ] noise scheduler for student (curriculum factor that goes from 0 to 1 from epochs 1 to 100)
207 | - [ ] fix student training, it does not look correct
208 | - [ ] make sure tbptt is set up correctly
209 | - [ ] add reward crafting as in the paper
210 | - [ ] play around with DeepMind's MuJoCo
211 | 
212 | ## Diagrams
213 | 
214 | 
215 | 
216 | ## Citations
217 | 
218 | ```bibtex
219 | @article{2022,
220 |   title = {Learning robust perceptive locomotion for quadrupedal robots in the wild},
221 |   url = {http://dx.doi.org/10.1126/scirobotics.abk2822},
222 |   journal = {Science Robotics},
223 |   publisher = {American Association for the Advancement of Science (AAAS)},
224 |   author = {Miki, Takahiro and Lee, Joonho and Hwangbo, Jemin and Wellhausen, Lorenz and Koltun, Vladlen and Hutter, Marco},
225 |   year = {2022},
226 |   month = {Jan}
227 | }
228 | ```
229 | 
--------------------------------------------------------------------------------
/anymal_belief_state_encoder_decoder_pytorch/ppo.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple, deque
2 | 
3 | import torch
4 | from torch import nn, cat, stack
5 | from torch.nn import Module
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset, DataLoader, TensorDataset
8 | from torch.optim import Adam
9 | 
10 | from assoc_scan import AssocScan
11 | 
12 | from anymal_belief_state_encoder_decoder_pytorch import Anymal
13 | from anymal_belief_state_encoder_decoder_pytorch.networks import unfreeze_all_layers_
14 | 
15 | from einops import rearrange
16 | 
17 | # helper functions
18 | 
19 | def exists(val):
20 |     return val is not None
21 | 
22 | def default(v, d):
23 |     return v if exists(v) else d
24 | 
25 | # they use basic PPO for training the teacher with privileged information
26 | # then they use noisy student training, using the trained "oracle" teacher as a guide
27 | 
28 | # ppo data
29 | 
30 | Memory = namedtuple('Memory', ['state', 'action', 'action_log_prob', 'reward', 'done', 'value'])
31 | 
32 | def create_shuffled_dataloader(data, batch_size):
33 |     ds = TensorDataset(*data)
34 |     return DataLoader(ds, batch_size = batch_size, shuffle = True)
35 | 
36 | # ppo helper functions
37 | 
38 | def normalize(t, eps = 1e-5):
39 |     return (t - t.mean()) / (t.std() + eps)
40 | 
41 | # generalized advantage estimate
42 | 
43 | def calc_generalized_advantage_estimate(
44 |     rewards,
45 |     values,
46 |     masks,
47 |     gamma = 0.99,
48 |     lam = 0.95,
49 |     use_accelerated = None
50 | ):
51 |     device, is_cuda = rewards.device, rewards.is_cuda
52 |     use_accelerated = default(use_accelerated, is_cuda)
53 | 
54 |     values, values_next = values[:-1], values[1:]
55 | 
56 |     delta = rewards + gamma * values_next * masks - values
57 |     gates = gamma * lam * masks
58 | 
59 |     scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
60 | 
61 |     gae = scan(gates, delta)
62 |     return gae + values
63 | 
64 | def
clipped_value_loss(values, rewards, old_values, clip): 65 | value_clipped = old_values + (values - old_values).clamp(-clip, clip) 66 | value_loss_1 = (value_clipped.flatten() - rewards) ** 2 67 | value_loss_2 = (values.flatten() - rewards) ** 2 68 | return torch.mean(torch.max(value_loss_1, value_loss_2)) 69 | 70 | # mock environment 71 | 72 | class MockEnv(object): 73 | def __init__( 74 | self, 75 | proprio_dim, 76 | extero_dim, 77 | privileged_dim, 78 | num_legs = 4 79 | ): 80 | self.proprio_dim = proprio_dim 81 | self.extero_dim = extero_dim 82 | self.privileged_dim = privileged_dim 83 | self.num_legs = num_legs 84 | 85 | def rand_state(self): 86 | return ( 87 | torch.randn((self.proprio_dim,)), 88 | torch.randn((self.num_legs, self.extero_dim,)), 89 | torch.randn((self.privileged_dim,)) 90 | ) 91 | 92 | def reset(self): 93 | return self.rand_state() 94 | 95 | def step(self, action): 96 | reward = torch.randn((1,)) 97 | done = torch.tensor([False]) 98 | return self.rand_state(), reward, done, None 99 | 100 | # main ppo class 101 | 102 | class PPO(Module): 103 | def __init__( 104 | self, 105 | *, 106 | env, 107 | anymal, 108 | epochs = 2, 109 | lr = 5e-4, 110 | betas = (0.9, 0.999), 111 | eps_clip = 0.2, 112 | beta_s = 0.005, 113 | value_clip = 0.4, 114 | max_timesteps = 10000, 115 | update_timesteps = 5000, 116 | lam = 0.95, 117 | gamma = 0.99, 118 | minibatch_size = 8300 119 | ): 120 | super().__init__() 121 | assert isinstance(anymal, Anymal) 122 | self.env = env 123 | self.anymal = anymal 124 | 125 | self.minibatch_size = minibatch_size 126 | self.optimizer = Adam(anymal.teacher.parameters(), lr = lr, betas = betas) 127 | self.epochs = epochs 128 | 129 | self.max_timesteps = max_timesteps 130 | self.update_timesteps = update_timesteps 131 | 132 | self.beta_s = beta_s 133 | self.eps_clip = eps_clip 134 | self.value_clip = value_clip 135 | 136 | self.lam = lam 137 | self.gamma = gamma 138 | 139 | # in paper, they said observations fed to teacher were normalized 140 | # by running mean 141 | 142 | self.running_proprio, self.running_extero = anymal.get_observation_running_stats() 143 | 144 | def learn_from_memories( 145 | self, 146 | memories, 147 | next_states 148 | ): 149 | device = next(self.parameters()).device 150 | 151 | # retrieve and prepare data from memory for training 152 | 153 | ( 154 | states, 155 | actions, 156 | old_log_probs, 157 | rewards, 158 | dones, 159 | values 160 | ) = tuple(zip(*memories)) 161 | 162 | states = tuple(zip(*states)) 163 | 164 | # calculate generalized advantage estimate 165 | 166 | rewards = cat(rewards).to(device) 167 | values = cat(values).to(device).detach() 168 | masks = 1. - cat(dones).to(device).float() 169 | 170 | next_states = [t.to(device) for t in next_states] 171 | next_states = [rearrange(t, '... 
-> 1 ...') for t in next_states]
172 | 
173 |         with torch.no_grad():
174 |             self.anymal.eval()
175 |             _, next_value = self.anymal.forward_teacher(*next_states, return_value_head = True)
176 |             next_value = next_value.detach()
177 | 
178 |         values_with_next = cat((values, next_value))
179 | 
180 |         returns = calc_generalized_advantage_estimate(rewards, values_with_next, masks, self.gamma, self.lam).detach()
181 | 
182 |         # convert values to torch tensors
183 | 
184 |         to_torch_tensor = lambda t: stack(t).to(device).detach()
185 | 
186 |         states = map(to_torch_tensor, states)
187 |         actions = to_torch_tensor(actions)
188 |         old_log_probs = to_torch_tensor(old_log_probs)
189 | 
190 |         # prepare dataloader for policy phase training
191 | 
192 |         dl = create_shuffled_dataloader([*states, actions, old_log_probs, returns, values], self.minibatch_size)
193 | 
194 |         # policy phase training, similar to original PPO
195 | 
196 |         for _ in range(self.epochs):
197 |             for proprio, extero, privileged, actions, old_log_probs, rewards, old_values in dl:
198 | 
199 |                 dist, values = self.anymal.forward_teacher(
200 |                     proprio, extero, privileged,
201 |                     return_value_head = True,
202 |                     return_action_categorical_dist = True
203 |                 )
204 | 
205 |                 action_log_probs = dist.log_prob(actions)
206 | 
207 |                 entropy = dist.entropy()
208 |                 ratios = (action_log_probs - old_log_probs).exp()
209 |                 advantages = normalize(rewards - old_values.detach())
210 |                 surr1 = ratios * advantages
211 |                 surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
212 | 
213 |                 policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropy
214 | 
215 |                 value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip)
216 | 
217 |                 (policy_loss.mean() + value_loss.mean()).backward()
218 |                 self.optimizer.step()
219 |                 self.optimizer.zero_grad()
220 | 
221 |     # does one episode's worth of learning
222 | 
223 |     def forward(self):
224 |         device = next(self.parameters()).device
225 |         unfreeze_all_layers_(self.anymal)
226 | 
227 |         time = 0
228 |         states = self.env.reset() # states assumed to be (proprioception, exteroception, privileged information)
229 |         memories = deque([])
230 | 
231 |         self.running_proprio.clear()
232 |         self.running_extero.clear()
233 | 
234 |         for timestep in range(self.max_timesteps):
235 |             time += 1
236 | 
237 |             states = list(map(lambda t: t.to(device), states))
238 |             proprio, extero, privileged = states
239 | 
240 |             # update running means for observations, for teacher
241 | 
242 |             self.running_proprio.push(proprio)
243 |             self.running_extero.push(extero)
244 | 
245 |             # normalize observation states for teacher (proprio and extero)
246 | 
247 |             states = (
248 |                 self.running_proprio.norm(proprio),
249 |                 self.running_extero.norm(extero),
250 |                 privileged
251 |             )
252 | 
253 |             anymal_states = list(map(lambda t: rearrange(t, '...
-> 1 ...'), states)) 254 | 255 | dist, values = self.anymal.forward_teacher( 256 | *anymal_states, 257 | return_value_head = True, 258 | return_action_categorical_dist = True 259 | ) 260 | 261 | action = dist.sample() 262 | action_log_prob = dist.log_prob(action) 263 | 264 | next_states, reward, done, _ = self.env.step(action) 265 | 266 | memory = Memory(states, action, action_log_prob, reward, done, values) 267 | memories.append(memory) 268 | 269 | states = next_states 270 | 271 | if time % self.update_timesteps == 0: 272 | self.learn_from_memories(memories, next_states) 273 | memories.clear() 274 | 275 | if done: 276 | break 277 | 278 | print('trained for 1 episode') 279 | -------------------------------------------------------------------------------- /anymal_belief_state_encoder_decoder_pytorch/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import GRUCell 5 | from torch.distributions import Categorical 6 | from torch.optim import Adam 7 | 8 | from einops import rearrange 9 | from einops.layers.torch import Rearrange 10 | 11 | from anymal_belief_state_encoder_decoder_pytorch.running import RunningStats 12 | 13 | # helper functions 14 | 15 | def exists(val): 16 | return val is not None 17 | 18 | def check_shape(tensor, pattern, **kwargs): 19 | return rearrange(tensor, f'{pattern} -> {pattern}', **kwargs) 20 | 21 | # freezing of neural networks (teacher needs to be frozen) 22 | 23 | def set_module_requires_grad_(module, requires_grad): 24 | for param in module.parameters(): 25 | param.requires_grad = requires_grad 26 | 27 | def freeze_all_layers_(module): 28 | set_module_requires_grad_(module, False) 29 | 30 | def unfreeze_all_layers_(module): 31 | set_module_requires_grad_(module, True) 32 | 33 | # in the paper 34 | # the network attention gates the exteroception, and then sums it to the belief state 35 | # todo: make sure the padding is on the right side 36 | 37 | def sum_with_zeropad(x, y): 38 | x_dim, y_dim = x.shape[-1], y.shape[-1] 39 | 40 | if x_dim == y_dim: 41 | return x + y 42 | 43 | if x_dim < y_dim: 44 | x = F.pad(x, (y_dim - x_dim, 0)) 45 | 46 | if y_dim < x_dim: 47 | y = F.pad(y, (x_dim - y_dim, 0)) 48 | 49 | return x + y 50 | 51 | # add basic MLP 52 | 53 | class MLP(nn.Module): 54 | def __init__( 55 | self, 56 | dims, 57 | activation = nn.LeakyReLU, 58 | final_activation = False 59 | ): 60 | super().__init__() 61 | assert isinstance(dims, (list, tuple)) 62 | assert len(dims) > 2, 'must have at least 3 dimensions (input, *hiddens, output)' 63 | 64 | dim_pairs = list(zip(dims[:-1], dims[1:])) 65 | *dim_pairs, dim_out_pair = dim_pairs 66 | 67 | layers = [] 68 | for dim_in, dim_out in dim_pairs: 69 | layers.extend([ 70 | nn.Linear(dim_in, dim_out), 71 | activation() 72 | ]) 73 | 74 | layers.append(nn.Linear(*dim_out_pair)) 75 | 76 | if final_activation: 77 | layers.append(activation()) 78 | 79 | self.net = nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | if isinstance(x, (tuple, list)): 83 | x = torch.cat(x, dim = -1) 84 | 85 | return self.net(x) 86 | 87 | class Student(nn.Module): 88 | def __init__( 89 | self, 90 | num_actions, 91 | proprio_dim = 133, 92 | extero_dim = 52, # in paper, height samples was marked as 208, but wasn't sure if that was per leg, or (4 legs x 52) = 208 93 | latent_extero_dim = 24, 94 | extero_encoder_hidden = (80, 60), 95 | belief_state_encoder_hiddens = (64, 64), 96 | extero_gate_encoder_hiddens = 
(64, 64),
97 |         belief_state_dim = 120, # should be equal to teacher's extero_dim + privileged_dim (part of the GRU's responsibility is to maintain a hidden state that forms an opinion on the privileged information)
98 |         gru_num_layers = 2,
99 |         gru_hidden_size = 50,
100 |         mlp_hidden = (256, 160, 128),
101 |         num_legs = 4,
102 |         privileged_dim = 50,
103 |         privileged_decoder_hiddens = (64, 64),
104 |         extero_decoder_hiddens = (64, 64),
105 |     ):
106 |         super().__init__()
107 |         assert belief_state_dim > (num_legs * latent_extero_dim)
108 |         self.num_legs = num_legs
109 |         self.proprio_dim = proprio_dim
110 |         self.extero_dim = extero_dim
111 | 
112 |         # encoding of exteroception
113 | 
114 |         self.extero_encoder = MLP((extero_dim, *extero_encoder_hidden, latent_extero_dim))
115 | 
116 |         # GRU related parameters
117 | 
118 |         gru_input_dim = (latent_extero_dim * num_legs) + proprio_dim
119 |         gru_input_dims = (gru_input_dim, *((gru_hidden_size,) * (gru_num_layers - 1)))
120 |         self.gru_cells = nn.ModuleList([GRUCell(input_dim, gru_hidden_size) for input_dim in gru_input_dims])
121 |         self.gru_hidden_size = gru_hidden_size
122 | 
123 |         # belief state encoding
124 | 
125 |         self.belief_state_encoder = MLP((gru_hidden_size, *belief_state_encoder_hiddens, belief_state_dim))
126 | 
127 |         # attention gating of exteroception
128 | 
129 |         self.to_latent_extero_attn_gate = MLP((gru_hidden_size, *extero_gate_encoder_hiddens, latent_extero_dim * num_legs))
130 | 
131 |         # belief state decoder
132 | 
133 |         self.privileged_decoder = MLP((gru_hidden_size, *privileged_decoder_hiddens, privileged_dim))
134 |         self.extero_decoder = MLP((gru_hidden_size, *extero_decoder_hiddens, extero_dim * num_legs))
135 | 
136 |         self.to_extero_attn_gate = MLP((gru_hidden_size, *extero_gate_encoder_hiddens, extero_dim * num_legs))
137 | 
138 |         # final MLP to action logits
139 | 
140 |         self.to_logits = MLP((
141 |             belief_state_dim + proprio_dim,
142 |             *mlp_hidden
143 |         ))
144 | 
145 |         self.to_action_head = nn.Sequential(
146 |             nn.LeakyReLU(),
147 |             nn.Linear(mlp_hidden[-1], num_actions)
148 |         )
149 | 
150 |     def get_gru_hiddens(self):
151 |         device = next(self.parameters()).device
152 |         return torch.zeros((len(self.gru_cells), self.gru_hidden_size), device = device)
153 | 
154 |     def forward(
155 |         self,
156 |         proprio,
157 |         extero,
158 |         hiddens = None,
159 |         return_estimated_info = False, # for returning estimated privileged info + exteroceptive info, for reconstruction loss
160 |         return_action_categorical_dist = False
161 |     ):
162 |         check_shape(proprio, 'b d', d = self.proprio_dim)
163 |         check_shape(extero, 'b n d', n = self.num_legs, d = self.extero_dim)
164 | 
165 |         latent_extero = self.extero_encoder(extero)
166 |         latent_extero = rearrange(latent_extero, 'b ...
-> b (...)') 167 | 168 | # RNN 169 | 170 | if not exists(hiddens): 171 | prev_hiddens = (None,) * len(self.gru_cells) 172 | else: 173 | prev_hiddens = hiddens.unbind(dim = -2) 174 | 175 | gru_input = torch.cat((proprio, latent_extero), dim = -1) 176 | 177 | next_hiddens = [] 178 | for gru_cell, prev_hidden in zip(self.gru_cells, prev_hiddens): 179 | gru_input = gru_cell(gru_input, prev_hidden) 180 | next_hiddens.append(gru_input) 181 | 182 | gru_output = gru_input 183 | 184 | next_hiddens = torch.stack(next_hiddens, dim = -2) 185 | 186 | # attention gating of exteroception 187 | 188 | latent_extero_attn_gate = self.to_latent_extero_attn_gate(gru_output) 189 | gated_latent_extero = latent_extero * latent_extero_attn_gate.sigmoid() 190 | 191 | # belief state and add gated exteroception 192 | 193 | belief_state = self.belief_state_encoder(gru_output) 194 | belief_state = sum_with_zeropad(belief_state, gated_latent_extero) 195 | 196 | # to action logits 197 | 198 | belief_state_with_proprio = torch.cat(( 199 | proprio, 200 | belief_state, 201 | ), dim = 1) 202 | 203 | logits = self.to_logits(belief_state_with_proprio) 204 | 205 | pi_logits = self.to_action_head(logits) 206 | 207 | return_action = Categorical(pi_logits.softmax(dim = -1)) if return_action_categorical_dist else pi_logits 208 | 209 | if not return_estimated_info: 210 | return return_action, next_hiddens 211 | 212 | # belief state decoding 213 | # for reconstructing privileged and exteroception information from hidden belief states 214 | 215 | recon_privileged = self.privileged_decoder(gru_output) 216 | recon_extero = self.extero_decoder(gru_output) 217 | extero_attn_gate = self.to_extero_attn_gate(gru_output) 218 | 219 | gated_extero = rearrange(extero, 'b ... -> b (...)') * extero_attn_gate.sigmoid() 220 | recon_extero = recon_extero + gated_extero 221 | recon_extero = rearrange(recon_extero, 'b (n d) -> b n d', n = self.num_legs) 222 | 223 | # whether to return raw policy logits or action probs wrapped with Categorical 224 | 225 | return return_action, next_hiddens, (recon_privileged, recon_extero) 226 | 227 | class Teacher(nn.Module): 228 | def __init__( 229 | self, 230 | num_actions, 231 | proprio_dim = 133, 232 | extero_dim = 52, # in paper, height samples was marked as 208, but wasn't sure if that was per leg, or (4 legs x 52) = 208 233 | latent_extero_dim = 24, 234 | extero_encoder_hidden = (80, 60), 235 | privileged_dim = 50, 236 | latent_privileged_dim = 24, 237 | privileged_encoder_hidden = (64, 32), 238 | mlp_hidden = (256, 160, 128), 239 | num_legs = 4 240 | ): 241 | super().__init__() 242 | self.num_legs = num_legs 243 | self.proprio_dim = proprio_dim 244 | self.extero_dim = extero_dim 245 | self.privileged_dim = privileged_dim 246 | 247 | self.extero_encoder = MLP((extero_dim, *extero_encoder_hidden, latent_extero_dim)) 248 | self.privileged_encoder = MLP((privileged_dim, *privileged_encoder_hidden, latent_privileged_dim)) 249 | 250 | self.to_logits = MLP(( 251 | latent_extero_dim * num_legs + latent_privileged_dim + proprio_dim, 252 | *mlp_hidden 253 | )) 254 | 255 | self.to_action_head = nn.Sequential( 256 | nn.LeakyReLU(), 257 | nn.Linear(mlp_hidden[-1], num_actions) 258 | ) 259 | 260 | self.to_value_head = nn.Sequential( 261 | nn.LeakyReLU(), 262 | nn.Linear(mlp_hidden[-1], 1), 263 | Rearrange('... 
1 -> ...') 264 | ) 265 | 266 | def forward( 267 | self, 268 | proprio, 269 | extero, 270 | privileged, 271 | return_value_head = False, 272 | return_action_categorical_dist = False 273 | ): 274 | check_shape(proprio, 'b d', d = self.proprio_dim) 275 | check_shape(extero, 'b n d', n = self.num_legs, d = self.extero_dim) 276 | check_shape(privileged, 'b d', d = self.privileged_dim) 277 | 278 | latent_extero = self.extero_encoder(extero) 279 | latent_extero = rearrange(latent_extero, 'b ... -> b (...)') 280 | 281 | latent_privileged = self.privileged_encoder(privileged) 282 | 283 | latent = torch.cat(( 284 | proprio, 285 | latent_extero, 286 | latent_privileged, 287 | ), dim = -1) 288 | 289 | logits = self.to_logits(latent) 290 | 291 | pi_logits = self.to_action_head(logits) 292 | 293 | if not return_value_head: 294 | return pi_logits 295 | 296 | value_logits = self.to_value_head(logits) 297 | 298 | return_action = Categorical(pi_logits.softmax(dim = -1)) if return_action_categorical_dist else pi_logits 299 | return return_action, value_logits 300 | 301 | # manages both teacher and student under one module 302 | 303 | class Anymal(nn.Module): 304 | def __init__( 305 | self, 306 | num_actions, 307 | proprio_dim = 133, 308 | extero_dim = 52, 309 | privileged_dim = 50, 310 | num_legs = 4, 311 | latent_extero_dim = 24, 312 | latent_privileged_dim = 24, 313 | teacher_extero_encoder_hidden = (80, 60), 314 | teacher_privileged_encoder_hidden = (64, 32), 315 | student_extero_gate_encoder_hiddens = (64, 64), 316 | student_belief_state_encoder_hiddens = (64, 64), 317 | student_belief_state_dim = 120, 318 | student_gru_num_layers = 2, 319 | student_gru_hidden_size = 50, 320 | student_privileged_decoder_hiddens = (64, 64), 321 | student_extero_decoder_hiddens = (64, 64), 322 | student_extero_encoder_hidden = (80, 60), 323 | mlp_hidden = (256, 160, 128), 324 | recon_loss_weight = 0.5 325 | ): 326 | super().__init__() 327 | self.proprio_dim = proprio_dim 328 | self.num_legs = num_legs 329 | self.extero_dim = extero_dim 330 | 331 | self.student = Student( 332 | num_actions = num_actions, 333 | proprio_dim = proprio_dim, 334 | extero_dim = extero_dim, 335 | latent_extero_dim = latent_extero_dim, 336 | extero_encoder_hidden = student_extero_encoder_hidden, 337 | belief_state_encoder_hiddens = student_belief_state_encoder_hiddens, 338 | extero_gate_encoder_hiddens = student_extero_gate_encoder_hiddens, 339 | belief_state_dim = student_belief_state_dim, 340 | gru_num_layers = student_gru_num_layers, 341 | gru_hidden_size = student_gru_hidden_size, 342 | mlp_hidden = mlp_hidden, 343 | num_legs = num_legs, 344 | privileged_dim = privileged_dim, 345 | privileged_decoder_hiddens = student_privileged_decoder_hiddens, 346 | extero_decoder_hiddens = student_extero_decoder_hiddens, 347 | ) 348 | 349 | self.teacher = Teacher( 350 | num_actions = num_actions, 351 | proprio_dim = proprio_dim, 352 | extero_dim = extero_dim, 353 | latent_extero_dim = latent_extero_dim, 354 | extero_encoder_hidden = teacher_extero_encoder_hidden, 355 | privileged_dim = privileged_dim, 356 | latent_privileged_dim = latent_privileged_dim, 357 | privileged_encoder_hidden = teacher_privileged_encoder_hidden, 358 | mlp_hidden = mlp_hidden, 359 | num_legs = num_legs 360 | ) 361 | 362 | self.recon_loss_weight = recon_loss_weight 363 | 364 | def get_observation_running_stats(self): 365 | return RunningStats(self.proprio_dim), RunningStats((self.num_legs, self.extero_dim)) 366 | 367 | def init_student_with_teacher(self): 368 | 
self.student.extero_encoder.load_state_dict(self.teacher.extero_encoder.state_dict()) 369 | self.student.to_logits.load_state_dict(self.teacher.to_logits.state_dict()) 370 | self.student.to_action_head.load_state_dict(self.teacher.to_action_head.state_dict()) 371 | 372 | def forward_teacher(self, *args, return_value_head = False, **kwargs): 373 | return self.teacher(*args, return_value_head = return_value_head, **kwargs) 374 | 375 | def forward_student(self, *args, **kwargs): 376 | return self.student(*args, **kwargs) 377 | 378 | # main forward for training the student with teacher as guide 379 | 380 | def forward( 381 | self, 382 | proprio, 383 | extero, 384 | privileged, 385 | teacher_states = None, 386 | hiddens = None, 387 | noise_strength = 0.1 388 | ): 389 | self.teacher.eval() 390 | freeze_all_layers_(self.teacher) 391 | 392 | with torch.no_grad(): 393 | teacher_proprio, teacher_extero = teacher_states if exists(teacher_states) else (proprio, extero) 394 | teacher_action_logits = self.forward_teacher(teacher_proprio, teacher_extero, privileged) 395 | 396 | noised_extero = extero + torch.rand_like(extero) * noise_strength 397 | 398 | student_action_logits, hiddens, recons = self.student(proprio, noised_extero, hiddens = hiddens, return_estimated_info = True) 399 | 400 | # calculate reconstruction loss of privileged and denoised exteroception 401 | 402 | (recon_privileged, recon_extero) = recons 403 | recon_loss = F.mse_loss(recon_privileged, privileged) + F.mse_loss(recon_extero, extero) 404 | 405 | # calculate behavior loss, which is also squared distance? 406 | 407 | behavior_loss = F.mse_loss(teacher_action_logits, student_action_logits) # why not kl div on action probs? 408 | 409 | loss = behavior_loss + recon_loss * self.recon_loss_weight 410 | return loss, hiddens 411 | --------------------------------------------------------------------------------