├── .gitignore ├── EXAMPLES.md ├── LICENSE ├── PytorchRouting ├── CoreLayers │ ├── Initialization.py │ ├── Loss.py │ ├── Selection.py │ └── __init__.py ├── DecisionLayers │ ├── Decision.py │ ├── Others │ │ ├── GumbelSoftmax.py │ │ ├── PassThrough.py │ │ ├── PerTaskAssignment.py │ │ ├── RELAX.py │ │ └── __init__.py │ ├── PolicyStorage.py │ ├── ReinforcementLearning │ │ ├── AAC.py │ │ ├── ActorCritic.py │ │ ├── AdvantageLearning.py │ │ ├── QLearning.py │ │ ├── REINFORCE.py │ │ ├── SARSA.py │ │ ├── WPL.py │ │ └── __init__.py │ └── __init__.py ├── Examples │ ├── Datasets.py │ ├── Models.py │ ├── __init__.py │ └── run_experiments.py ├── Helpers │ ├── MLP.py │ ├── RLSample.py │ ├── SampleMetaInformation.py │ ├── TorchHelpers.py │ └── __init__.py ├── PreFabs │ ├── RNNcells.py │ └── __init__.py ├── RewardFunctions │ ├── Final │ │ ├── BaseReward.py │ │ ├── CorrectClassifiedReward.py │ │ ├── NegLossReward.py │ │ └── __init__.py │ ├── PerAction │ │ ├── CollaborationReward.py │ │ ├── ManualReward.py │ │ ├── PerActionBaseReward.py │ │ └── __init__.py │ └── __init__.py ├── UtilLayers │ ├── Sequential.py │ └── __init__.py └── __init__.py ├── README.md └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # datasets 29 | *.pkl.* 30 | Datasets/ 31 | 32 | # PyCharm 33 | .idea/ 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | -------------------------------------------------------------------------------- /EXAMPLES.md: -------------------------------------------------------------------------------- 1 | ### Examples 2 | The following examples can be found in `PytorchRouting.Examples.Models`, and can be tested on CIFAR100-MTL and MNIST-MTL with `PytorchRouting.Examples.run_experiment.py`. 3 | 4 | #### Per-Task Agents 5 | The most architectures introduced in the paper assign one agent exclusively to each task. 
Using Pytorch-Routing, these can be implemented as follows (here, using a WPL MARL agent): 6 | ```Python 7 | class RoutedAllFC(nn.Module): 8 | def __init__(self, in_channels, convnet_out_size, out_dim, num_modules, num_agents): 9 | nn.Module.__init__(self) 10 | self.convolutions = nn.Sequential( 11 | SimpleConvNetBlock(in_channels, 32, 3), 12 | SimpleConvNetBlock(32, 32, 3), 13 | SimpleConvNetBlock(32, 32, 3), 14 | SimpleConvNetBlock(32, 32, 3), 15 | nn.BatchNorm2d(32), 16 | Flatten() 17 | ) 18 | self._loss_func = Loss(torch.nn.MSELoss(), CorrectClassifiedReward(), discounting=1.) 19 | 20 | self._initialization = Initialization() 21 | self._per_task_assignment = PerTaskAssignment() 22 | 23 | self._decision_1 = WPL( 24 | num_modules, convnet_out_size, num_agents=num_agents, policy_storage_type='tabular', 25 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 26 | self._decision_2 = WPL( 27 | num_modules, convnet_out_size, num_agents=num_agents, policy_storage_type='tabular', 28 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 29 | self._decision_3 = WPL( 30 | num_modules, convnet_out_size, num_agents=num_agents, policy_storage_type='tabular', 31 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 32 | 33 | self._selection_1 = Selection(*[LinearWithRelu(convnet_out_size, 48) for _ in range(num_modules)]) 34 | self._selection_2 = Selection(*[LinearWithRelu(48, 48) for _ in range(num_modules)]) 35 | self._selection_3 = Selection(*[nn.Linear(48, out_dim) for _ in range(num_modules)]) 36 | # self._selection_f = Selection(*[nn.Linear(48, out_dim) for _ in range(num_modules)]) 37 | 38 | def forward(self, x, tasks): 39 | y = self.convolutions(x) 40 | y, meta, actions = self._initialization(y, tasks=tasks) 41 | y, meta, task_actions = self._per_task_assignment(y, meta, actions) 42 | y, meta, routing_actions_1 = self._decision_1(y, meta, task_actions) 43 | y, meta, _ = self._selection_1(y, meta, routing_actions_1) 44 | y, meta, routing_actions_2 = self._decision_2(y, meta, task_actions) 45 | y, meta, _ = self._selection_2(y, meta, routing_actions_2) 46 | y, meta, routing_actions_3 = self._decision_3(y, meta, task_actions) 47 | y, meta, _ = self._selection_3(y, meta, routing_actions_3) 48 | return y, meta 49 | 50 | def loss(self, yhat, ytrue, ym): 51 | return self._loss_func(yhat, ym, ytrue) 52 | ``` 53 | Again, the `PerTaskAssignment` layer is used to produce actions. However, these actions do not select a module; they select a decision-making agent instead. They are therefore not overridden along the way, but are passed explicitly to each of the decision layers as dispatcher actions. 54 | 55 | #### Dispatched Routing Architectures 56 | Extending this paradigm to dispatched architectures is straightforward: 57 | ```Python 58 | class DispatchedRoutedAllFC(RoutedAllFC): 59 | def __init__(self, dispatcher_decision_maker, decision_maker, in_channels, convnet_out_size, 60 | out_dim, num_modules, num_agents): 61 | RoutedAllFC.__init__(self, in_channels, convnet_out_size, out_dim, num_modules, num_agents) 62 | self._per_task_assignment = dispatcher_decision_maker( 63 | num_agents, convnet_out_size, num_agents=1, policy_storage_type='approx', 64 | additional_reward_func=CollaborationReward(reward_ratio=0.0, num_actions=num_modules)) 65 | ``` 66 | Here, the task-specific assignment "agent" is simply replaced by a separate, trainable dispatching agent.
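#### Training a routed model
The `Loss` layer returns two values: a module loss that trains the selected modules, and a routing loss that trains the decision makers. The sketch below shows one way a training step might look, assuming the class above is available and simply summing the two losses; the constructor arguments, tensor names and `dataloader` are placeholders, and `PytorchRouting.Examples.run_experiments.py` may weight or schedule the losses differently.
```Python
import torch

# placeholder sizes; convnet_out_size depends on the input resolution after the conv blocks
model = RoutedAllFC(in_channels=3, convnet_out_size=512, out_dim=10, num_modules=5, num_agents=5)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for x, targets, tasks in dataloader:  # placeholder loader yielding (inputs, targets, task ids)
    optimizer.zero_grad()
    yhat, meta = model(x, tasks)      # predictions plus per-sample meta information
    # targets must match the prediction shape for the MSE loss above (e.g. one-hot labels)
    module_loss, routing_loss = model.loss(yhat, targets, meta)
    (module_loss + routing_loss).backward()  # train modules and decision makers jointly
    optimizer.step()
```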
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /PytorchRouting/CoreLayers/Initialization.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class RoutingTechnicalLayers. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | import torch.nn as nn 8 | 9 | from PytorchRouting.Helpers.SampleMetaInformation import SampleMetaInformation 10 | 11 | 12 | class Initialization(nn.Module): 13 | """ 14 | The initialization class defines a thin layer that initializes the meta-information and actions - composing 15 | the pytorch-routing information triplet. 16 | """ 17 | 18 | def __init__(self): 19 | nn.Module.__init__(self) 20 | 21 | def forward(self, xs, tasks=()): 22 | if len(tasks) > 0: 23 | mxs = [SampleMetaInformation(task=t) for t in tasks] 24 | else: 25 | mxs = [SampleMetaInformation() for _ in xs] 26 | return xs, mxs, None 27 | -------------------------------------------------------------------------------- /PytorchRouting/CoreLayers/Loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class RoutingLoss. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | import abc 8 | from collections import defaultdict 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from PytorchRouting.RewardFunctions.Final import CorrectClassifiedReward 14 | 15 | 16 | class Loss(nn.Module, metaclass=abc.ABCMeta): 17 | """ 18 | This function defines the combined module/decision loss functions. It performs four steps that will result in 19 | separate losses for the modules and the decision makers: 20 | 1. it computes the module losses 21 | 2. it translates these module losses into per-sample reinforcement learning rewards 22 | 3. it uses these final rewards to compute the full rl-trajectories for each sample 23 | 4. 
it uses the decision-making specific loss functions to compute the total decision making loss 24 | """ 25 | 26 | def __init__(self, pytorch_loss_func, routing_reward_func, discounting=1., clear=False, 27 | normalize_per_action_rewards=True): 28 | nn.Module.__init__(self) 29 | self._discounting = discounting 30 | self._loss_func = pytorch_loss_func 31 | self._clear = clear 32 | try: 33 | self._loss_func.reduction = 'none' 34 | except AttributeError: 35 | pass 36 | self._reward_func = routing_reward_func 37 | self._npar = normalize_per_action_rewards 38 | 39 | def _get_rl_loss_tuple_map(self, mys, device): 40 | rl_loss_tuple_map = defaultdict(lambda: defaultdict(list)) 41 | reward_functions = set() 42 | for traj_counter, my in zip(torch.arange(len(mys), device=device).unsqueeze(1), mys): 43 | my.finalize() # translates the trajectory from a list of obj into lists 44 | my.add_rewards = [ar if ar is not None else 0. for ar in my.add_rewards] \ 45 | if hasattr(my, 'add_rewards') else [0.] * len(my.actions) 46 | assert len(my.actions) == len(my.states) == len(my.add_rewards) == len(my.reward_func) 47 | rewards = [] 48 | # computing the rewards 49 | for state, action, reward_func, add_r in zip(my.states, my.actions, my.reward_func, my.add_rewards): 50 | # normalize the per-action reward to the entire sequence length 51 | per_action_reward = (reward_func.get_reward(state, action) + add_r) / len(my.actions) 52 | # normalize to the final reward, s.t. it will be interpreted as a fraction thereof 53 | per_action_reward = per_action_reward * torch.abs(my.final_reward) if self._npar else per_action_reward 54 | rewards.append(per_action_reward) 55 | reward_functions.add(reward_func) 56 | rewards[-1] += my.final_reward 57 | returns = [0.] 58 | # computing the returns 59 | for i, rew in enumerate(reversed(rewards)): 60 | returns.insert(0, rew + returns[0]) 61 | returns = returns[:-1] 62 | # creating the tensors to compute the loss from the SARSA tuple 63 | for lf, s, a, rew, ret, pa, ns, na in zip(my.loss_funcs, my.states, my.actions, rewards, returns, 64 | ([None] + my.actions)[:-1], 65 | (my.states + [None])[1:], 66 | (my.actions + [None])[1:]): 67 | is_terminal = ns is None or s.numel() != ns.numel() 68 | rl_loss_tuple_map[lf]['indices'].append(traj_counter) 69 | rl_loss_tuple_map[lf]['is_terminal'].append(torch.tensor([is_terminal], dtype=torch.uint8, device=device)) 70 | rl_loss_tuple_map[lf]['states'].append(s) 71 | rl_loss_tuple_map[lf]['actions'].append(a.view(-1)) 72 | rl_loss_tuple_map[lf]['rewards'].append(rew.view(-1)) 73 | rl_loss_tuple_map[lf]['returns'].append(ret.view(-1)) 74 | rl_loss_tuple_map[lf]['final_reward'].append(my.final_reward.view(-1)) 75 | rl_loss_tuple_map[lf]['prev_actions'].append(a.new_zeros(1) if pa is None else pa.view(-1)) 76 | rl_loss_tuple_map[lf]['next_states'].append(s if is_terminal else ns) 77 | rl_loss_tuple_map[lf]['next_actions'].append(a.new_zeros(1) if is_terminal else na.view(-1)) 78 | # concatenating the retrieved values into tensors 79 | for k0, v0 in rl_loss_tuple_map.items(): 80 | for k1, v1 in v0.items(): 81 | v0[k1] = torch.cat(v1, dim=0) 82 | if self._clear: 83 | for rf in reward_functions: 84 | rf.clear() 85 | return rl_loss_tuple_map 86 | 87 | def forward(self, ysest, mys, ystrue=None, external_losses=None, reduce=True): 88 | assert not(ystrue is None and external_losses is None), \ 89 | 'Must provide ystrue and possibly external_losses (or both).' 
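# overall flow: compute (or take) the per-sample module losses, translate them into final rewards, build the per-sample RL tuples, and evaluate each decision maker's loss function on its batched tuples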
90 | batch_size = ysest.size(0) 91 | if external_losses is not None: 92 | # first case: external losses are provided externally 93 | assert external_losses.size()[0] == len(mys), 'One loss value per sample is required.' 94 | module_loss = external_losses.view(external_losses.size()[0], -1).sum(dim=1) 95 | else: 96 | # second case: they are not, so we need to compute them 97 | module_loss = self._loss_func(ysest, ystrue) 98 | if len(module_loss.size()) > 1: 99 | module_loss = module_loss.sum(dim=1).reshape(-1) 100 | if ystrue is None: 101 | # more input checking 102 | assert not isinstance(self._reward_func, CorrectClassifiedReward), \ 103 | 'Must provide ystrue when using CorrectClassifiedReward' 104 | ystrue = ysest.new_zeros(batch_size) 105 | assert len(module_loss) == len(mys) == len(ysest) == len(ystrue), \ 106 | 'Losses, metas, predictions and targets need to have the same length ({}, {}, {}, {})'.format( 107 | len(module_loss), len(mys), len(ysest), len(ystrue)) 108 | # add the final reward, as we can only compute them now that we have the external feedback 109 | for l, my, yest, ytrue in zip(module_loss.split(1, dim=0), mys, ysest.split(1, dim=0), ystrue.split(1, dim=0)): 110 | my.final_reward = self._reward_func(l, yest, ytrue) 111 | # retrieve the SARSA pairs to compute the respective decision making losses 112 | rl_loss_tuple_map = self._get_rl_loss_tuple_map(mys, device=ysest.device) 113 | # initialize the rl loss 114 | routing_loss = torch.zeros(batch_size, dtype=torch.float, device=ysest.device) 115 | for loss_func, rl_dict in rl_loss_tuple_map.items(): 116 | # batch the RL loss by loss function, if possible 117 | rl_losses = loss_func(rl_dict['is_terminal'], rl_dict['states'], rl_dict['next_states'], rl_dict['actions'], 118 | rl_dict['next_actions'], rl_dict['rewards'], rl_dict['returns'], rl_dict['final_reward']) 119 | for i in torch.arange(batch_size, device=ysest.device): 120 | # map the losses back onto the sample indices 121 | routing_loss[i] = routing_loss[i] + torch.sum(rl_losses[rl_dict['indices'] == i]) 122 | if reduce: 123 | module_loss = module_loss.mean() 124 | routing_loss = routing_loss.mean() 125 | return module_loss, routing_loss 126 | -------------------------------------------------------------------------------- /PytorchRouting/CoreLayers/Selection.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class Model. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/6/18 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | # from torch.multiprocessing import Pool 10 | 11 | 12 | class Selection(nn.Module): 13 | """ 14 | Class RoutingWrapperModule defines a wrapper around a regular pytorch module that computes the actual routing 15 | given a list of modules to choose from, and a list of actions to select a module for each sample in a batch. 16 | """ 17 | 18 | def __init__(self, *modules, name='', store_module_pointers=False): 19 | nn.Module.__init__(self) 20 | self.name = name 21 | # self._threads = threads 22 | self._submodules = nn.ModuleList(modules) 23 | self._selection_log = [] 24 | self._logging_selections = False 25 | self.__output_dim = None 26 | self._store_module_pointers = store_module_pointers 27 | 28 | def forward(self, xs, mxs, actions, mask=None): 29 | """ 30 | This method takes a list of samples - a batch - and calls _forward_sample on each. Samples are 31 | a tensor where the first dimension is the batch dimension. 
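In the current implementation, samples are grouped by their action, each group is dispatched to the corresponding submodule, and the per-group outputs are scattered back into a single output tensor.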
32 | :param xs: 33 | :param mxs: 34 | :param actions: 35 | :param mask: a torch.ByteTensor that determines if the trajectory is active. if it is not, no action 36 | will be executed 37 | :return: 38 | """ 39 | assert len(xs) == len(mxs) 40 | batch_size = xs.size(0) 41 | # capture the special case of just one submodule - and skip all computation 42 | if len(self._submodules) == 1: 43 | return self._submodules[0](xs), mxs, actions 44 | # retrieving output dim for output instantiation 45 | if self.__output_dim is None: 46 | self.__output_dim = self._submodules[0](xs[0].unsqueeze(0)).shape[1:] 47 | # initializing the "termination" mask 48 | mask = torch.ones(batch_size, dtype=torch.uint8, device=xs.device) \ 49 | if mask is None else mask 50 | # parallelizing this loop does not work. however, we can split the batch by the actions 51 | # creating the target variable 52 | ys = torch.zeros((batch_size, *self.__output_dim), dtype=torch.float, device=xs.device) 53 | for i in torch.arange(actions.max() + 1, device=xs.device): 54 | if i not in actions: 55 | continue 56 | # computing the mask as the currently active action on the active trajectories 57 | m = ((actions == i) * mask) 58 | if not any(m): 59 | continue 60 | ys[m] = self._submodules[i](xs[m]) 61 | if self._logging_selections: 62 | self._selection_log += actions.reshape(-1).cpu().tolist() 63 | if self._store_module_pointers: 64 | for mx, a in zip(mxs, actions): 65 | mx.append('selected_modules', self._submodules[a]) 66 | return ys, mxs, actions 67 | 68 | def start_logging_selections(self): 69 | self._logging_selections = True 70 | 71 | def stop_logging_and_get_selections(self, add_to_old=False): 72 | self._logging_selections = False 73 | logs = list(set([int(s) for s in self._selection_log])) 74 | del self._selection_log[:] 75 | self.last_selection_freeze = logs + self.last_selection_freeze if add_to_old else logs 76 | return self.last_selection_freeze 77 | -------------------------------------------------------------------------------- /PytorchRouting/CoreLayers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Initialization import Initialization 2 | from .Loss import Loss 3 | from .Selection import Selection 4 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/Decision.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class DecisionModule. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import abc 8 | import copy 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.distributions.distribution import Distribution 14 | 15 | from .PolicyStorage import ApproxPolicyStorage, TabularPolicyStorage 16 | from PytorchRouting.RewardFunctions.PerAction.PerActionBaseReward import PerActionBaseReward 17 | 18 | 19 | class Decision(nn.Module, metaclass=abc.ABCMeta): 20 | """ 21 | Class DecisionModule defines the base class for all decision modules. 
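A decision module selects one of `num_selections` options for every sample in a batch, appends the chosen action, its decision state and the associated loss function to each sample's meta-information, and returns the (possibly transformed) inputs together with the actions. Rough usage sketch (argument values are placeholders): `decision = QLearning(num_selections=3, in_features=64)`, then `ys, metas, actions = decision(xs, metas)`, with `metas` coming from an `Initialization` layer.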
22 | """ 23 | 24 | def __init__( 25 | self, 26 | num_selections, 27 | in_features, 28 | num_agents=1, 29 | exploration=0.1, 30 | policy_storage_type='approx', 31 | detach=True, 32 | approx_hidden_dims=(), 33 | policy_net=None, 34 | additional_reward_func=PerActionBaseReward(), 35 | set_pg_temp=False, 36 | **kwargs 37 | ): 38 | nn.Module.__init__(self) 39 | self._in_features = in_features 40 | self._num_selections = num_selections 41 | self._num_agents = num_agents 42 | self._exploration = exploration 43 | self._detach = detach 44 | self._pol_type = policy_storage_type 45 | self._pol_hidden_dims = approx_hidden_dims 46 | self._policy = self._construct_policy_storage( 47 | self._num_selections, self._pol_type, policy_net, self._pol_hidden_dims) 48 | self.additional_reward_func = additional_reward_func 49 | self._dist_dim = 1 50 | self._set_pg_temp = set_pg_temp 51 | self._pg_temperature = 1. 52 | 53 | def set_exploration(self, exploration): 54 | self._exploration = exploration 55 | 56 | @abc.abstractmethod 57 | def _forward(self, xs, prior_action): 58 | return torch.zeros(1, 1), [], torch.zeros(1, 1) 59 | 60 | @staticmethod 61 | def _eval_stochastic_are_exp(actions, dist): 62 | if len(dist.shape) == 3: 63 | dist = dist[:, :, 0] 64 | return (torch.max(dist, dim=1)[1].view(-1) == actions.view(-1)).byte() 65 | 66 | @abc.abstractmethod 67 | def _forward(self, xs, prior_action): 68 | return torch.zeros(1, 1), [], torch.zeros(1, 1) 69 | 70 | @staticmethod 71 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 72 | pass 73 | 74 | def _construct_policy_storage(self, out_dim, policy_storage_type, approx_module, approx_hidden_dims, in_dim=None): 75 | in_dim = in_dim or self._in_features 76 | if approx_module is not None: 77 | policy = nn.ModuleList( 78 | [ApproxPolicyStorage(approx=copy.deepcopy(approx_module), detach=self._detach) 79 | for _ in range(self._num_agents)] 80 | ) 81 | elif policy_storage_type in ('approx', 0): 82 | policy = nn.ModuleList( 83 | [ApproxPolicyStorage( 84 | in_features=in_dim, 85 | num_selections=out_dim, 86 | hidden_dims=approx_hidden_dims, 87 | detach=self._detach) 88 | for _ in range(self._num_agents)] 89 | ) 90 | elif policy_storage_type in ('tabular', 1): 91 | policy = nn.ModuleList( 92 | [TabularPolicyStorage(num_selections=out_dim) 93 | for _ in range(self._num_agents)] 94 | ) 95 | else: 96 | raise ValueError(f'Policy storage type {policy_storage_type} not understood.') 97 | return policy 98 | 99 | def forward(self, xs, mxs, prior_actions=None, mask=None, update_target=None): 100 | """ 101 | The forward method of DecisionModule takes a batch of inputs, and a list of metainformation, and 102 | append the decision made to the metainformation objects. 103 | :param xs: 104 | :param mxs: 105 | :param prior_actions: prior actions that select the agent 106 | :param mask: a torch.ByteTensor that determines if the trajectory is active. if it is not, no action 107 | will be executed 108 | :param update_target: (only relevant for GumbelSoftmax) if specified, this will include the gradientflow 109 | in update_target, and will thus return update_target 110 | :return: xs OR update_target, if specified, with potentially an attached backward object 111 | """ 112 | # input checking 113 | assert len(xs) == len(mxs) 114 | batch_size = xs.size(0) 115 | assert self._num_agents == 1 or prior_actions is not None, \ 116 | 'Decision makers with more than one action have to have prior_actions provided.' 
117 | assert mask is None or mask.max() == 1, \ 118 | 'Please check that a batch being passed in has at least one active (non terminated) trajectory.' 119 | # computing the termination mask and the prior actions if not passed in 120 | mask = torch.ones(batch_size, dtype=torch.uint8, device=xs.device) \ 121 | if mask is None else mask 122 | prior_actions = torch.zeros(batch_size, dtype=torch.long, device=xs.device) \ 123 | if prior_actions is None or len(prior_actions) == 0 else prior_actions.reshape(-1) 124 | ys = xs.clone() if update_target is None else update_target.clone() # required as in-place ops follow 125 | # initializing the return vars 126 | actions = torch.zeros(batch_size, dtype=torch.long, device=xs.device) 127 | are_exp = torch.zeros(batch_size, dtype=torch.uint8, device=xs.device) 128 | dists = torch.zeros((batch_size, self._num_selections, 5), device=xs.device) 129 | # "clustering" by agent 130 | for i in torch.arange(0, prior_actions.max() + 1, device=xs.device): 131 | if i not in prior_actions: 132 | continue 133 | # computing the mask as the currently computed agent on the active trajectories 134 | m = ((prior_actions == i) * mask) 135 | if not any(m): 136 | continue 137 | # selecting the actions 138 | y, a, e, d = self._forward(xs[m], i) 139 | # merging the results 140 | ys[m], actions[m], are_exp[m], dists[m, :, :d.size(-1)] = \ 141 | y, a.view(-1), e.view(-1), d.view(d.size(0), d.size(1), -1) 142 | actions = actions.view(-1) # flattens the actions tensor, but does not produce a scalar 143 | assert len(actions) == len(are_exp) == dists.size(0) == len(mxs) 144 | # amending the metas 145 | for ia, a, e, d, mx in zip(mask, actions, are_exp, dists.split(1, dim=0), mxs): 146 | if ia: 147 | mx.append('actions', a, new_step=True) 148 | mx.append('is_exploratory', e.squeeze()) 149 | mx.append('states', d) 150 | mx.append('loss_funcs', self._loss) 151 | mx.append('reward_func', self.additional_reward_func) 152 | self.additional_reward_func.register(d, a) 153 | return ys, mxs, actions 154 | 155 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/Others/GumbelSoftmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class GumbelSoftmax. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/12/18 6 | """ 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | from ..Decision import Decision 13 | 14 | 15 | class GumbelSoftmax(Decision): 16 | """ 17 | Class GumbelSoftmax defines a decision making procedure that uses the GumbelSoftmax reparameterization trick 18 | to perform differentiable sampling from the categorical distribution. 
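The exploration setting is mapped to the sampling temperature (roughly the range [0.1, 10]), so higher exploration yields softer, more uniform samples.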
19 | """ 20 | def __init__(self, *args, **kwargs): 21 | Decision.__init__(self, *args, **kwargs) 22 | # translating exploration into the sampling temperature parameter in [0.1, 10] 23 | self._gumbel_softmax = GumbelSoftmaxSampling(temperature_init=1) 24 | self.set_exploration(self._exploration) 25 | 26 | def set_exploration(self, exploration): 27 | temperature = 0.1 + 9.9*exploration 28 | self._gumbel_softmax.set_temperature(temperature) 29 | 30 | @staticmethod 31 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, final_reward): 32 | return torch.zeros_like(action, dtype=torch.float, device=action.device) 33 | 34 | def _forward(self, xs, agent): 35 | logits = self._policy[agent](xs) 36 | if self.training: 37 | actions, multiples = self._gumbel_softmax.sample(logits) 38 | else: 39 | actions = logits.max(dim=1)[1] 40 | multiples = 1. 41 | # shape casting to allow for mask-multiply 42 | ys = (xs.contiguous().view(xs.size(0), -1) * multiples).contiguous().view(xs.shape) 43 | return ys, actions, self._eval_stochastic_are_exp(actions, logits), logits 44 | 45 | 46 | class GumbelSoftmaxSampling(nn.Module): 47 | """ 48 | This class defines the core functionality to sample from a gumbel softmax distribution 49 | """ 50 | 51 | def __init__(self, temperature_init=1., hard=True, hook=None): 52 | nn.Module.__init__(self) 53 | self._temperature = temperature_init 54 | self.softmax = nn.Softmax(dim=1) 55 | self._hard = hard 56 | self._hook = hook 57 | 58 | def set_temperature(self, temperature): 59 | self._temperature = temperature 60 | # print('The new temperature param is: {}'.format(self._temperature)) 61 | 62 | @staticmethod 63 | def _sample_gumble(shape, eps=1e-20): 64 | U = torch.FloatTensor(*shape) 65 | U.uniform_(0, 1) 66 | logs = -torch.log(-torch.log(U + eps) + eps) 67 | return logs 68 | 69 | def _gumbel_softmax_sample(self, logits): 70 | y = logits + Variable(self._sample_gumble(logits.size())).to(logits.device) 71 | dist = self.softmax(y / self._temperature) 72 | if self._hook is not None: 73 | dist.register_hook(self._hook) 74 | return dist 75 | 76 | def forward(self, logits): 77 | y = self._gumbel_softmax_sample(logits) 78 | if self._hard: 79 | _, y_hard_index = torch.max(y, len(y.size())-1) 80 | y_hard = y.clone().data.zero_() 81 | y_hard[0, y_hard_index.squeeze()] = 1. 82 | y_no_grad = y.detach() 83 | y = y_hard - y_no_grad + y 84 | return y 85 | 86 | def sample(self, logits): 87 | y = self._gumbel_softmax_sample(logits) 88 | _, y_hard_index = torch.max(y, dim=-1) 89 | index = y_hard_index.detach().view(-1, 1) 90 | y_fake_grad = y - y.detach() 91 | # if len(y_fake_grad.shape) == 1: 92 | # y_fake_grad = y_fake_grad.view(1, -1) 93 | multiplier = 1 + torch.gather(y_fake_grad, 1, index) 94 | return index, multiplier 95 | 96 | def __call__(self, *args, **kwargs): 97 | return self.forward(*args, **kwargs) 98 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/Others/PassThrough.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the pass through decision maker. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 2/7/19 6 | """ 7 | import torch 8 | 9 | from ..Decision import Decision 10 | 11 | 12 | class PassThrough(Decision): 13 | """ 14 | This helper decision module does not actually make any decision, but is only a dummy useful for some 15 | implementations. 
16 | """ 17 | def __init__(self, *args, **kwargs): 18 | Decision.__init__(self, None, None, ) 19 | 20 | @staticmethod 21 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 22 | return torch.zeros(1).to(action.device) 23 | 24 | def _construct_policy_storage(self, _1, _2, _3, _4): 25 | return [] 26 | 27 | def _forward(self, xs, prior_action): pass 28 | 29 | def forward(self, xs, mxs, _=None, __=None): 30 | return xs, mxs, torch.zeros(len(mxs), dtype=torch.long, device=xs.device) 31 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/Others/PerTaskAssignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class PerTaskAssignment. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/12/18 6 | """ 7 | """ 8 | This file defines class REINFORCE. 9 | 10 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 11 | @created: 6/7/18 12 | """ 13 | import torch 14 | 15 | from ..Decision import Decision 16 | 17 | 18 | class PerTaskAssignment(Decision): 19 | """ 20 | This simple class translates task assignments stored in the meta-information objects to actions. 21 | """ 22 | def __init__(self, *args, **kwargs): 23 | Decision.__init__(self, None, None, ) 24 | 25 | @staticmethod 26 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 27 | return torch.zeros(1).to(action.device) 28 | 29 | def _construct_policy_storage(self, _1, _2, _3, _4): 30 | return [] 31 | 32 | def _forward(self, xs, prior_action): pass 33 | 34 | def forward(self, xs, mxs, _=None, __=None): 35 | actions = torch.LongTensor([m.task for m in mxs]).to(xs.device) 36 | return xs, mxs, actions 37 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/Others/RELAX.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REBAR. Implementation largely taken from: 3 | https://github.com/duvenaud/relax/blob/master/pytorch_toy.py 4 | 5 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 6 | @created: 6/12/18 7 | """ 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import grad as compute_gradient 14 | 15 | from ..Decision import Decision 16 | 17 | 18 | def entropy(probs): 19 | return torch.sum(- torch.log(probs) * probs, dim=1) 20 | 21 | 22 | def make_samples(logits, temp, eps=1e-8): 23 | u1 = torch.zeros(*logits.shape, device=logits.device) 24 | u2 = torch.zeros(*logits.shape, device=logits.device) 25 | u1.uniform_() 26 | u2.uniform_() 27 | # temp = tf.exp(log_temp) 28 | # logprobs = tf.nn.log_softmax(logits) 29 | logprobs = torch.distributions.Categorical(logits=logits).logits 30 | g = -torch.log(-torch.log(u1 + eps) + eps) 31 | scores = logprobs_z = logprobs + g 32 | hard_samples = torch.argmax(scores, dim=1) 33 | # hard_samples_oh = tf.one_hot(hard_samples, scores.get_shape().as_list()[1]) 34 | hard_samples_onehot = torch.zeros(hard_samples.size(0), scores.size(1), device=logits.device) 35 | hard_samples_onehot.scatter_(1, hard_samples.unsqueeze(1), 1) 36 | 37 | g2 = -torch.log(-torch.log(u2 + eps) + eps) 38 | scores2 = logprobs + g2 39 | 40 | # B = tf.reduce_sum(scores2 * hard_samples_oh, axis=1, keep_dims=True) - logprobs 41 | B = scores2 * hard_samples_onehot - logprobs 42 | y = -1. 
* torch.log(u2) + torch.exp(-1. * B) 43 | g3 = -1. * torch.log(y) 44 | scores3 = g3 + logprobs 45 | # slightly biased… 46 | logprobs_zt = hard_samples_onehot * scores2 + ((-1. * hard_samples_onehot) + 1.) * scores3 47 | return hard_samples, F.softmax(logprobs_z / temp, dim=1), F.softmax(logprobs_zt / temp, dim=1) 48 | 49 | 50 | class RELAX(Decision): 51 | """ 52 | Class GumbelSoftmax defines a decision making procedure that uses the GumbelSoftmax reparameterization trick 53 | to perform differentiable sampling from the categorical distribution. 54 | """ 55 | def __init__(self, *args, value_net=None, **kwargs): 56 | Decision.__init__(self, *args, **kwargs) 57 | # translating exploration into the sampling temperature parameter in [0.1, 10] 58 | 59 | self._value_mem = self._construct_policy_storage( 60 | 1, self._pol_type, value_net, self._pol_hidden_dims, in_dim=self._in_features + self._num_selections) 61 | self._temperature = 0.5 62 | self._value_coefficient = 0.5 63 | self._entropy_coefficient = 0.01 64 | 65 | def set_exploration(self, exploration): 66 | self._temperature = 0.1 + 9.9*exploration 67 | 68 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 69 | if not self.training: 70 | # we cannot compute gradients in test mode, so we cannot compute the respective losses 71 | return torch.zeros(state.size(0), device=state.device) 72 | # oh_A = tf.one_hot(train_model.a0, ac_space.n) 73 | onehot_action = torch.zeros(state.size(0), state.size(1), device=state.device) 74 | onehot_action.scatter_(1, action.unsqueeze(1), 1) 75 | 76 | log_policy = state[:, :, 0] 77 | values = state[:, :, 1] 78 | values_t = state[:, :, 2] 79 | policy = F.softmax(log_policy, dim=1) 80 | policy_entropy = entropy(policy) 81 | 82 | # params = find_trainable_variables("model") 83 | params = self.parameters() 84 | # policy_params = [v for v in params if "pi" in v.name] 85 | policy_params = list(self._policy.parameters()) 86 | # vf_params = [v for v in params if "vf" in v.name] 87 | vf_params = list(self._value_mem.parameters()) 88 | # entropy_grads = tf.gradients(entropy, policy_params) 89 | entropy_grads = compute_gradient(policy_entropy.sum(), policy_params, retain_graph=True, allow_unused=True) 90 | 91 | 92 | # ddiff_loss = tf.reduce_sum(train_model.vf - train_model.vf_t) 93 | # ddiff_grads = tf.gradients(ddiff_loss, policy_params) 94 | 95 | ddiff_loss = (values - values_t).sum() 96 | ddiff_grads = compute_gradient(ddiff_loss, policy_params, 97 | retain_graph=True, only_inputs=True, create_graph=True, allow_unused=True) 98 | 99 | # sm = tf.nn.softmax(train_model.pi) 100 | dlogp_dpi = onehot_action * (1. - policy) + (1. 
- onehot_action) * (-policy) 101 | # pi_grads = -((tf.expand_dims(R, 1) - train_model.vf_t) * dlogp_dpi) 102 | pi_grads = -((cum_return.unsqueeze(1).expand_as(values_t) - values_t) * dlogp_dpi) 103 | # pg_grads = tf.gradients(train_model.pi, policy_params, grad_ys=pi_grads) 104 | pg_grads = compute_gradient(policy, policy_params, grad_outputs=pi_grads, 105 | retain_graph=True, create_graph=True, allow_unused=True) 106 | pg_grads = [pg - dg for pg, dg in zip(pg_grads, ddiff_grads) if pg is not None] 107 | 108 | # cv_grads = tf.concat([tf.reshape(p, [-1]) for p in pg_grads], 0) 109 | cv_grads = torch.cat([p.view(-1) for p in pg_grads], 0) 110 | cv_grad_splits = torch.pow(cv_grads, 2).sum() 111 | vf_loss = cv_grad_splits * self._value_coefficient 112 | 113 | 114 | for e_grad, p_grad, param in zip(entropy_grads, pg_grads, policy_params): 115 | if p_grad is None and e_grad is None: 116 | continue 117 | elif p_grad is None: 118 | p_grad = torch.zeros_like(e_grad) 119 | elif e_grad is None: 120 | e_grad = torch.zeros_like(p_grad) 121 | grad = -e_grad * self._entropy_coefficient + p_grad 122 | if param.grad is not None: 123 | param.grad.add_(grad) 124 | else: 125 | grad = grad.detach() 126 | grad.requires_grad_(False) 127 | param.grad = grad 128 | 129 | # cv_grads = compute_gradient(vf_loss, vf_params) 130 | # for cv_grad, param in zip(cv_grads, vf_params): 131 | # if param.grad is not None: 132 | # param.grad.add_(grad) 133 | # else: 134 | # grad = grad.detach() 135 | # grad.requires_grad_(False) 136 | # param.grad = grad 137 | vf_loss.backward() 138 | return torch.zeros(state.size(0), device=state.device) 139 | 140 | def _forward(self, xs, agent): 141 | logits = self._policy[agent](xs) 142 | a0, s0, st0 = make_samples(logits, self._temperature) 143 | values = self._value_mem[agent](torch.cat([xs, s0], dim=1)) 144 | values_t = self._value_mem[agent](torch.cat([xs, st0], dim=1)) 145 | if self.training: 146 | actions = a0 147 | else: 148 | actions = logits.max(dim=1)[1] 149 | state = torch.stack([logits, values.expand_as(logits), values_t.expand_as(logits)], dim=2) 150 | return xs, actions, self._eval_stochastic_are_exp(actions, logits), state 151 | 152 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/Others/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cle-ros/RoutingNetworks/0f1fe1221c67a224a02bca6247d3c4488ede0a04/PytorchRouting/DecisionLayers/Others/__init__.py -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/PolicyStorage.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class ApproxPolicyStorageDecisionModule. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import abc 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | from PytorchRouting.Helpers.MLP import MLP 12 | 13 | 14 | class ApproxPolicyStorage(nn.Module, metaclass=abc.ABCMeta): 15 | """ 16 | Class ApproxPolicyStorage defines a simple module to store a policy approximator. 
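If no approximator module is passed in, a small MLP mapping `in_features` to `num_selections` (with the given hidden dimensions) is constructed; when `detach` is set, inputs are detached so that policy gradients do not flow back into the routed model.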
17 | """ 18 | def __init__(self, approx=None, in_features=None, num_selections=None, hidden_dims=(), detach=True): 19 | nn.Module.__init__(self) 20 | self._detach = detach 21 | if approx: 22 | self._approx = approx 23 | else: 24 | self._approx = MLP( 25 | in_features, 26 | num_selections, 27 | hidden_dims 28 | ) 29 | 30 | def forward(self, xs): 31 | if self._detach: 32 | xs = xs.detach() 33 | policies = self._approx(xs) 34 | return policies 35 | 36 | 37 | class TabularPolicyStorage(nn.Module, metaclass=abc.ABCMeta): 38 | """ 39 | Class TabularPolicyStorage defines a simple module to store a policy in tabular form. 40 | """ 41 | def __init__(self, approx=None, in_features=None, num_selections=None, hidden_dims=()): 42 | nn.Module.__init__(self) 43 | self._approx = nn.Parameter( 44 | torch.ones(1, num_selections).float()/num_selections 45 | ) 46 | 47 | def forward(self, xs): 48 | policies = torch.cat([self._approx] * xs.shape[0], dim=0) 49 | return policies 50 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/AAC.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import copy 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from ..Decision import Decision 12 | 13 | 14 | class AAC(Decision): 15 | """ 16 | ActorCritic based decision making. 17 | """ 18 | def __init__(self, *args, value_net=None, **kwargs): 19 | Decision.__init__(self, *args, **kwargs) 20 | self._value_mem = self._construct_policy_storage( 21 | 1, self._pol_type, value_net, self._pol_hidden_dims) 22 | 23 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 24 | advantages = (cum_return - state[:, 0, 1]) 25 | act_loss = - state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) * advantages.detach() 26 | val_loss = advantages.pow(2).mean() 27 | return act_loss + val_loss 28 | 29 | def _forward(self, xs, agent): 30 | policy = self._policy[agent](xs) 31 | values = self._value_mem[agent](xs).expand_as(policy) 32 | distribution = torch.distributions.Categorical(logits=policy) 33 | if self.training: 34 | actions = distribution.sample() 35 | else: 36 | actions = distribution.logits.max(dim=1)[1] 37 | state = torch.stack([distribution.logits, values], 2) 38 | return xs, actions, self._eval_stochastic_are_exp(actions, state), state 39 | 40 | 41 | class BootstrapAAC(AAC): 42 | """ 43 | ActorCritic based decision making. 
44 | """ 45 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 46 | next_state = torch.where(is_terminal, torch.zeros_like(next_state, device=state.device), next_state) 47 | advantages = (next_state[:, 0, 1] + reward - state[:, 0, 1]) 48 | act_loss = - state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) * advantages.detach() 49 | val_loss = advantages.pow(2).mean() 50 | return act_loss + val_loss 51 | 52 | 53 | class EGreedyAAC(AAC): 54 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 55 | importance_weights = (state[:, :, 1] / state[:, :, 2]).gather(1, action.unsqueeze(1)) 56 | advantages = (cum_return.detach() - state[:, 0, 3]) 57 | importance_weighted_advantages = (importance_weights * advantages).detach() 58 | act_loss = - state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) * importance_weighted_advantages 59 | val_loss = advantages.pow(2).mean() 60 | return act_loss + val_loss 61 | 62 | def _forward(self, xs, agent): 63 | batch_dim = xs.size(0) 64 | policy = self._policy[agent](xs) 65 | values = self._value_mem[agent](xs).expand_as(policy) 66 | distribution = torch.distributions.Categorical(logits=policy) 67 | if self.training: 68 | exploration_dist = torch.ones(batch_dim, 2).float() 69 | exploration_dist[:, 0] *= 1 - self._exploration 70 | exploration_dist[:, 1] *= self._exploration 71 | explore_bin = torch.multinomial(exploration_dist, 1).byte().to(xs.device) 72 | selected_probs, greedy = distribution.logits.max(dim=1) 73 | on_policy = distribution.sample() 74 | actions = torch.where(explore_bin, on_policy.unsqueeze(-1), greedy.unsqueeze(-1)) 75 | 76 | # computing the importance weights 77 | sampling_dist = distribution.probs.detach() * \ 78 | (1 / (1 - selected_probs.unsqueeze(1).expand_as(policy))) * \ 79 | (self._exploration / (policy.size(1) - 1)) 80 | sampling_dist.scatter_(1, 81 | greedy.unsqueeze(1), 82 | torch.ones(batch_dim, 1, device=xs.device) * (1 - self._exploration)) 83 | state = torch.stack((distribution.logits, distribution.probs, sampling_dist, values), dim=2) 84 | else: 85 | actions = distribution.logits.max(dim=1)[1] 86 | state = torch.stack((distribution.logits, distribution.probs, distribution.probs, values), dim=2) 87 | return xs, actions, self._eval_stochastic_are_exp(actions, state), state 88 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/ActorCritic.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import copy 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from .REINFORCE import REINFORCE 12 | 13 | 14 | class ActorCritic(REINFORCE): 15 | """ 16 | ActorCritic based decision making. 
17 | """ 18 | def __init__(self, *args, qvalue_net=None, **kwargs): 19 | REINFORCE.__init__(self, *args, **kwargs) 20 | if qvalue_net is None and 'policy_net' in kwargs: 21 | qvalue_net = copy.deepcopy(kwargs['policy_net']) 22 | self._qvalue_mem = self._construct_policy_storage( 23 | self._num_selections, self._pol_type, qvalue_net, self._pol_hidden_dims) 24 | 25 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 26 | normalized_return = (cum_return - state[:, :, 1].gather(index=action.unsqueeze(1), dim=1).view(-1)).detach() 27 | act_loss = - state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) * normalized_return 28 | value_target = torch.where(is_terminal, final_reward, next_state[:, :, 1].max(dim=1)[0] - reward).detach() 29 | val_loss = F.mse_loss(state[:, :, 1].gather(index=action.unsqueeze(1), dim=1).view(-1), 30 | value_target, reduction='none').view(-1) 31 | return act_loss + val_loss 32 | 33 | def _forward(self, xs, agent): 34 | policy = self._policy[agent](xs) 35 | values = self._qvalue_mem[agent](xs) 36 | distribution = torch.distributions.Categorical(logits=policy/self._pg_temperature) 37 | if self.training: 38 | actions = distribution.sample() 39 | else: 40 | actions = distribution.logits.max(dim=1)[1] 41 | state = torch.stack([distribution.logits, values], 2) 42 | return xs, actions, self._eval_stochastic_are_exp(actions, state), state 43 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/AdvantageLearning.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from ..Decision import Decision 11 | 12 | 13 | class AdvantageLearning(Decision): 14 | """ 15 | QLearning (state-action value function) based decision making. 
16 | """ 17 | def __init__(self, *args, value_net=None, **kwargs): 18 | Decision.__init__(self, *args, **kwargs) 19 | self._value_mem = self._construct_policy_storage( 20 | 1, self._pol_type, value_net, self._pol_hidden_dims) 21 | 22 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 23 | value_loss = F.mse_loss(state[:, 0, 1], cum_return, reduction='none').view(-1) 24 | qval_loss = F.mse_loss(state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1), 25 | cum_return - state[:, 0, 1].detach(), reduction='none') 26 | return value_loss + qval_loss 27 | 28 | def _forward(self, xs, agent): 29 | batch_dim = xs.size()[0] 30 | qvals = self._policy[agent](xs) 31 | value = self._value_mem[agent](xs) 32 | value = value.expand_as(qvals) 33 | exploration_dist = torch.ones(batch_dim, 2).float() 34 | exploration_dist[:, 0] *= 1-self._exploration 35 | exploration_dist[:, 1] *= self._exploration 36 | explore_bin = torch.multinomial(exploration_dist, 1).byte().to(xs.device) 37 | _, greedy = qvals.max(dim=1) 38 | if self.training: 39 | explore = torch.randint(low=0, high=qvals.size()[1], size=(batch_dim, 1)).to(xs.device).long() 40 | actions = torch.where(explore_bin, explore, greedy.unsqueeze(-1)) 41 | else: 42 | actions = greedy 43 | state = torch.stack([qvals, value], 2) 44 | return xs, actions, explore_bin, state 45 | 46 | 47 | class BootstrapAdvantageLearning(AdvantageLearning): 48 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 49 | vtarget = torch.where(is_terminal, final_reward, next_state[:, 0, 1] + reward).detach() 50 | value_loss = F.mse_loss(state[:, 0, 1], vtarget, reduction='none').view(-1) 51 | atarget = (vtarget - state[:, 0, 1]).detach() 52 | adv_loss = F.mse_loss(state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1), 53 | atarget, reduction='none') 54 | return value_loss + adv_loss 55 | 56 | 57 | class SurpriseLearning(AdvantageLearning): 58 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 59 | target = torch.where(is_terminal, final_reward, next_state[:, 0, 1] - reward).detach() 60 | value_loss = F.mse_loss(state[:, 0, 1], target, reduction='none').view(-1) 61 | qval_loss = F.mse_loss(state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1), 62 | target - state[:, 0, 1].detach(), reduction='none') 63 | return value_loss + qval_loss 64 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/QLearning.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from ..Decision import Decision 11 | 12 | 13 | class QLearning(Decision): 14 | """ 15 | QLearning (state-action value function) based decision making. 
16 | """ 17 | 18 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 19 | target = torch.where(is_terminal, final_reward, next_state[:, :, 0].max(dim=1)[0] - reward).detach() 20 | return F.mse_loss(state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1), 21 | target.view(-1), reduction='none') 22 | 23 | def _forward(self, xs, agent): 24 | batch_dim = xs.size()[0] 25 | policy = self._policy[agent](xs) 26 | exploration_dist = torch.ones(batch_dim, 2).float() 27 | exploration_dist[:, 0] *= 1-self._exploration 28 | exploration_dist[:, 1] *= self._exploration 29 | explore_bin = torch.multinomial(exploration_dist, 1).byte().to(xs.device) 30 | _, greedy = policy.max(dim=1) 31 | if self.training: 32 | explore = torch.randint(low=0, high=policy.size()[1], size=(batch_dim, 1)).to(xs.device).long() 33 | actions = torch.where(explore_bin, explore, greedy.unsqueeze(-1)) 34 | else: 35 | actions = greedy 36 | return xs, actions, explore_bin, policy 37 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/REINFORCE.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from ..Decision import Decision 11 | 12 | 13 | class REINFORCE(Decision): 14 | """ 15 | REINFORCE (likelihood ratio policy gradient) based decision making. 16 | """ 17 | 18 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 19 | return - state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) * cum_return.detach() 20 | 21 | def set_exploration(self, exploration): 22 | if self._set_pg_temp: 23 | self._pg_temperature = 0.1 + 9.9*exploration 24 | 25 | def _forward(self, xs, agent): 26 | policy = self._policy[agent](xs) 27 | distribution = torch.distributions.Categorical(logits=policy/self._pg_temperature) 28 | if self.training: 29 | actions = distribution.sample() 30 | else: 31 | actions = distribution.logits.max(dim=1)[1] 32 | return xs, actions, self._eval_stochastic_are_exp(actions, distribution.logits), distribution.logits 33 | 34 | 35 | class REINFORCEBl1(Decision): 36 | def __init__(self, *args, value_net=None, **kwargs): 37 | Decision.__init__(self, *args, **kwargs) 38 | self._value_mem = self._construct_policy_storage( 39 | self._num_selections, self._pol_type, value_net, self._pol_hidden_dims) 40 | 41 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 42 | normalized_return = (cum_return - state[:, 0, 1].view(-1)).detach() 43 | act_loss = - state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) * normalized_return 44 | value_target = torch.where(is_terminal, final_reward, next_state[:, :, 1].max(dim=1)[0] - reward).detach() 45 | val_loss = F.mse_loss(state[:, :, 1].gather(index=action.unsqueeze(1), dim=1).view(-1), 46 | value_target, reduction='none').view(-1) 47 | return act_loss + val_loss 48 | 49 | def _forward(self, xs, agent): 50 | policy = self._policy[agent](xs) 51 | values = self._value_mem[agent](xs).expand_as(policy) 52 | distribution = torch.distributions.Categorical(logits=policy) 53 | if self.training: 54 | actions = distribution.sample() 55 | else: 56 | actions = distribution.logits.max(dim=1)[1] 57 | state = 
torch.stack([distribution.logits, values], 2) 58 | return xs, actions, self._eval_stochastic_are_exp(actions, state), state 59 | 60 | 61 | class REINFORCEBl2(REINFORCEBl1): 62 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 63 | logits = state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) 64 | probs = F.softmax(state[:, :, 0], dim=1).gather(index=action.unsqueeze(1), dim=1).view(-1) 65 | normalized_return = (cum_return/(self._num_selections * probs) - state[:, 0, 1].view(-1)).detach() 66 | act_loss = - logits * normalized_return 67 | value_target = torch.where(is_terminal, final_reward, next_state[:, :, 1].max(dim=1)[0] - reward).detach() 68 | val_loss = F.mse_loss(state[:, :, 1].gather(index=action.unsqueeze(1), dim=1).view(-1), 69 | value_target, reduction='none').view(-1) 70 | return act_loss + val_loss 71 | 72 | 73 | class EGreedyREINFORCE(REINFORCE): 74 | 75 | def set_exploration(self, exploration): 76 | # because of the special nature of this approach, exploration needs to be calculated differently 77 | self._exploration = min(1., 3.*exploration) 78 | 79 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 80 | importance_weights = (state[:, :, 1] / state[:, :, 2]).gather(1, action.unsqueeze(1)) 81 | importance_weighted_return = (importance_weights * cum_return).detach() 82 | return - state[:, :, 0].gather(dim=1, index=action.unsqueeze(1)).view(-1) * importance_weighted_return 83 | 84 | def _forward(self, xs, agent): 85 | batch_dim = xs.size(0) 86 | policy = self._policy[agent](xs) 87 | distribution = torch.distributions.Categorical(logits=policy) 88 | if self.training: 89 | exploration_dist = torch.ones(batch_dim, 2).float() 90 | exploration_dist[:, 0] *= 1 - self._exploration 91 | exploration_dist[:, 1] *= self._exploration 92 | explore_bin = torch.multinomial(exploration_dist, 1).byte().to(xs.device) 93 | selected_probs, greedy = distribution.logits.max(dim=1) 94 | on_policy = distribution.sample() 95 | actions = torch.where(explore_bin, on_policy.unsqueeze(-1), greedy.unsqueeze(-1)) 96 | 97 | # computing the importance weights 98 | sampling_dist = distribution.probs.detach() * \ 99 | (1 / (1 - selected_probs.unsqueeze(1).expand_as(policy))) * \ 100 | (self._exploration / (policy.size(1) - 1)) 101 | sampling_dist.scatter_(1, 102 | greedy.unsqueeze(1), 103 | torch.ones(batch_dim, 1, device=xs.device) * (1 - self._exploration)) 104 | state = torch.stack((distribution.logits, distribution.probs, sampling_dist), dim=2) 105 | else: 106 | actions = distribution.logits.max(dim=1)[1] 107 | state = torch.stack((distribution.logits, distribution.probs, distribution.probs), dim=2) 108 | return xs, actions, self._eval_stochastic_are_exp(actions, distribution.logits), state 109 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/SARSA.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .QLearning import QLearning 11 | 12 | 13 | class SARSA(QLearning): 14 | """ 15 | SARSA on-policy q-function learning. 
16 | """ 17 | # target = torch.where(is_terminal, reward, next_state[:, :, 0].max(dim=1)[0] - reward) 18 | # target = target.detach() 19 | # return F.mse_loss(state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1), 20 | # target.view(-1), reduction='none') 21 | 22 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 23 | target = torch.where(is_terminal, final_reward, 24 | state[:, :, 0].gather(index=next_action.unsqueeze(1), dim=1). 25 | view(-1)).detach() 26 | return F.mse_loss(state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1), 27 | target.view(-1), reduction='none') 28 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/WPL.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class REINFORCE. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import copy 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from .ActorCritic import ActorCritic 12 | 13 | 14 | class WPL(ActorCritic): 15 | """ 16 | Weighted Policy Learner (WPL) Multi-Agent Reinforcement Learning based decision making. 17 | """ 18 | 19 | def _loss(self, is_terminal, state, next_state, action, next_action, reward, cum_return, final_reward): 20 | grad_est = cum_return - state[:, :, 1].gather(index=action.unsqueeze(1), dim=1).view(-1) 21 | grad_projected = torch.where(grad_est < 0, 1. + grad_est, 2. - grad_est) 22 | prob_taken = state[:, :, 0].gather(index=action.unsqueeze(1), dim=1).view(-1) 23 | prob_target = (prob_taken * grad_projected).detach() 24 | act_loss = F.mse_loss(prob_taken, prob_target, reduction='none') 25 | ret_loss = F.mse_loss(state[:, :, 1].gather(index=action.unsqueeze(1), dim=1).view(-1), 26 | cum_return.detach(), reduction='none').view(-1) 27 | return act_loss + ret_loss 28 | 29 | def _forward(self, xs, agent): 30 | policy = self._policy[agent](xs) 31 | # policy = F.relu(policy) - F.relu(policy - 1.) + 1e-6 32 | policy = (policy.transpose(0, 1) - policy.min(dim=1)[0]).transpose(0, 1) + 1e-6 33 | # policy = policy/policy.sum(dim=1) 34 | values = self._qvalue_mem[agent](xs) 35 | distribution = torch.distributions.Categorical(probs=policy) 36 | if self.training: 37 | actions = distribution.sample() 38 | else: 39 | actions = distribution.logits.max(dim=1)[1] 40 | state = torch.stack([distribution.logits, values], 2) 41 | return xs, actions, self._eval_stochastic_are_exp(actions, state), state 42 | 43 | # @staticmethod 44 | # def _loss(sample): 45 | # grad_est = sample.cum_return - sample.state[:, 0, 1] 46 | # # ret_loss = F.smooth_l1_loss(sample.state[:, sample.action, 1], sample.cum_return).unsqueeze(-1) 47 | # ret_loss = F.smooth_l1_loss(sample.state[:, 0, 1], sample.cum_return).unsqueeze(-1) 48 | # grad_projected = grad_est * 1.3 49 | # grad_projected = torch.pow(grad_projected, 3.) 50 | # if grad_projected < 0: 51 | # pol_update = 1. + grad_projected 52 | # else: 53 | # pol_update = 2. 
- grad_projected 54 | # pol_update = sample.state[:, sample.action, 0] * pol_update 55 | # act_loss = F.smooth_l1_loss(sample.state[:, sample.action, 0], pol_update.data) 56 | # return act_loss + ret_loss 57 | 58 | # # @staticmethod 59 | # def _loss(self, sample): 60 | # grad_est = sample.cum_return - sample.state[:, 0, 1] 61 | # # ret_loss = F.smooth_l1_loss(sample.state[:, sample.action, 1], sample.cum_return).unsqueeze(-1) 62 | # # ret_loss = F.smooth_l1_loss(sample.state[:, 0, 1], sample.cum_return).unsqueeze(-1) 63 | # grad_projected = grad_est * 1.3 64 | # grad_projected = torch.pow(grad_projected, 3.) 65 | # if grad_projected < 0: 66 | # pol_update = 1. + grad_projected 67 | # else: 68 | # pol_update = 2. - grad_projected 69 | # pol_update = sample.state[:, sample.action, 0] * pol_update 70 | # self._policy[sample.prior_action]._approx.data[0, sample.action] = pol_update.data 71 | # self._qvalue_mem[sample.prior_action]._approx.data[0, sample.action] = \ 72 | # 0.9 * self._qvalue_mem[sample.prior_action]._approx.data[0, sample.action] + 0.1 * sample.cum_return 73 | # # act_loss = F.smooth_l1_loss(sample.state[:, sample.action, 0], pol_update.data) 74 | # return torch.zeros(1).to(sample.action.device) 75 | -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/ReinforcementLearning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cle-ros/RoutingNetworks/0f1fe1221c67a224a02bca6247d3c4488ede0a04/PytorchRouting/DecisionLayers/ReinforcementLearning/__init__.py -------------------------------------------------------------------------------- /PytorchRouting/DecisionLayers/__init__.py: -------------------------------------------------------------------------------- 1 | # RL approaches 2 | from .ReinforcementLearning.REINFORCE import REINFORCE, EGreedyREINFORCE, REINFORCEBl1, REINFORCEBl2 3 | from .ReinforcementLearning.QLearning import QLearning 4 | from .ReinforcementLearning.AdvantageLearning import AdvantageLearning 5 | from .ReinforcementLearning.SARSA import SARSA 6 | from .ReinforcementLearning.ActorCritic import ActorCritic 7 | from .ReinforcementLearning.AAC import AAC, BootstrapAAC, EGreedyAAC 8 | # MARL approaches 9 | from .ReinforcementLearning.WPL import WPL 10 | # Others 11 | from .Others.GumbelSoftmax import GumbelSoftmax 12 | from .Others.PerTaskAssignment import PerTaskAssignment 13 | from .Others.RELAX import RELAX 14 | from .Decision import Decision 15 | -------------------------------------------------------------------------------- /PytorchRouting/Examples/Datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines the base class `Dataset` and classes for the MNIST and CIFAR100 MTL versions. 3 | As this is mostly for demonstration purposes, the code is uncommented. 4 | 5 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 6 | @created: 6/14/18 7 | """ 8 | import abc 9 | import random 10 | try: 11 | import cPickle as pickle 12 | except ImportError: 13 | import pickle 14 | 15 | import numpy as np 16 | 17 | import torch 18 | from torch.autograd import Variable 19 | 20 | 21 | class Dataset(object, metaclass=abc.ABCMeta): 22 | """ 23 | Class Datasets defines ... 
24 | """ 25 | 26 | def __init__(self, batch_size, data_files=(), cuda=False): 27 | self._iterator = None 28 | self._batch_size = batch_size 29 | self._data_files = data_files 30 | self._train_set, self._test_set = self._get_datasets() 31 | self._cuda = cuda 32 | 33 | @abc.abstractmethod 34 | def _get_datasets(self): return [], [] 35 | 36 | def _batched_iter(self, dataset, batch_size): 37 | for i in range(0, len(dataset), batch_size): 38 | batch = dataset[i:i+batch_size] 39 | samples = Variable(torch.stack([torch.FloatTensor(sample[0]) for sample in batch], 0)) 40 | targets = Variable(torch.stack([torch.LongTensor([sample[1]]) for sample in batch], 0)) 41 | if self._cuda: 42 | samples = samples.cuda() 43 | targets = targets.cuda() 44 | tasks = [sample[2] for sample in batch] 45 | yield samples, targets, tasks 46 | 47 | def get_batch(self): 48 | return next(self._iterator) 49 | 50 | def enter_train_mode(self): 51 | random.shuffle(self._train_set) 52 | self._iterator = self._batched_iter(self._train_set, self._batch_size) 53 | 54 | def enter_test_mode(self): 55 | self._iterator = self._batched_iter(self._test_set, self._batch_size) 56 | 57 | 58 | class CIFAR100MTL(Dataset): 59 | def __init__(self, *args, **kwargs): 60 | Dataset.__init__(self, *args, **kwargs) 61 | self.num_tasks = 20 62 | 63 | def _get_datasets(self): 64 | datasets = [] 65 | for fn in self._data_files: # assuming that the datafiles are [train_file_name, test_file_name] 66 | samples, labels, tasks = [], [], [] 67 | with open(fn, 'rb') as f: 68 | data_dict = pickle.load(f, encoding='latin1') 69 | samples += [np.resize(s, (3, 32, 32)) for s in data_dict['data']] 70 | tasks += [int(fl) for fl in data_dict['coarse_labels']] 71 | labels += [int(cl) % 5 for cl in data_dict['fine_labels']] 72 | datasets.append(list(zip(samples, labels, tasks))) 73 | train_set, test_set = datasets 74 | return train_set, test_set 75 | -------------------------------------------------------------------------------- /PytorchRouting/Examples/Models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class Models. 
3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/14/18 6 | """ 7 | try: 8 | import cPickle as pickle 9 | except ImportError: 10 | import pickle 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from PytorchRouting.UtilLayers import Sequential 16 | 17 | from PytorchRouting.CoreLayers import Initialization, Loss, Selection 18 | from PytorchRouting.DecisionLayers import PerTaskAssignment 19 | from PytorchRouting.DecisionLayers.Decision import Decision 20 | from PytorchRouting.RewardFunctions.Final import CorrectClassifiedReward, NegLossReward 21 | from PytorchRouting.RewardFunctions.PerAction import CollaborationReward 22 | 23 | 24 | class SimpleConvNetBlock(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel): 26 | nn.Module.__init__(self) 27 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel, padding=1) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.maxpool = nn.MaxPool2d(2) 30 | 31 | def forward(self, x): 32 | y = self.conv(x) 33 | y = self.relu(y) 34 | y = self.maxpool(y) 35 | return y 36 | 37 | 38 | class Flatten(nn.Module): 39 | def forward(self, input): 40 | return input.view(input.size()[0], -1) 41 | 42 | 43 | class LinearWithRelu(nn.Linear): 44 | def forward(self, input): 45 | output = nn.Linear.forward(self, input) 46 | output = F.relu(output) 47 | return output 48 | 49 | 50 | class PerTask_all_fc(nn.Module): 51 | def __init__(self, in_channels, convnet_out_size, out_dim, num_modules, num_agents,): 52 | nn.Module.__init__(self) 53 | print('Per task exclusive: all fc') 54 | self.convolutions = nn.Sequential( 55 | SimpleConvNetBlock(in_channels, 32, 3), 56 | SimpleConvNetBlock(32, 32, 3), 57 | SimpleConvNetBlock(32, 32, 3), 58 | SimpleConvNetBlock(32, 32, 3), 59 | nn.BatchNorm2d(32), 60 | Flatten() 61 | ) 62 | # self._loss_layer = Loss(torch.nn.MSELoss(), CorrectClassifiedReward(), discounting=1.) 63 | self._loss_layer = Loss(torch.nn.MSELoss(), NegLossReward(), discounting=1.) 
64 | self.fc_layers = Sequential( 65 | PerTaskAssignment(), 66 | Selection(*[LinearWithRelu(convnet_out_size, 48) for _ in range(num_modules)]), 67 | Selection(*[LinearWithRelu(48, 48) for _ in range(num_modules)]), 68 | Selection(*[nn.Linear(48, out_dim) for _ in range(num_modules)]), 69 | ) 70 | 71 | def forward(self, x, tasks): 72 | y = self.convolutions(x) 73 | y, meta = self.fc_layers(y, tasks=tasks) 74 | return y, meta 75 | 76 | def loss(self, yhat, ytrue, ym): 77 | return self._loss_layer(yhat, ytrue, ym) 78 | 79 | def start_logging_selections(self): 80 | for m in self.modules(): 81 | if isinstance(m, Selection): 82 | m.start_logging_selections() 83 | 84 | def stop_logging_selections_and_report(self): 85 | modules_used = '' 86 | for m in self.modules(): 87 | if isinstance(m, Selection): 88 | selections = m.stop_logging_and_get_selections() 89 | if len(selections) > 0: 90 | modules_used += '{}, '.format(len(selections)) 91 | print(' Modules used: {}'.format(modules_used)) 92 | 93 | 94 | class PerTask_1_fc(PerTask_all_fc): 95 | def __init__(self, in_channels, convnet_out_size, out_dim, num_modules, num_agents,): 96 | PerTask_all_fc.__init__(self, in_channels, convnet_out_size, out_dim, num_modules, num_agents,) 97 | print('Per task exclusive: last fc') 98 | self.convolutions = nn.Sequential( 99 | self.convolutions, 100 | Flatten(), 101 | LinearWithRelu(convnet_out_size, 48), 102 | LinearWithRelu(48, 48) 103 | ) 104 | self._loss_layer = Loss(torch.nn.MSELoss(), CorrectClassifiedReward(), discounting=1.) 105 | self.fc_layers = Sequential( 106 | PerTaskAssignment(), 107 | Selection(*[nn.Linear(48, out_dim) for _ in range(num_modules)]), 108 | ) 109 | 110 | def forward(self, x, tasks): 111 | y = self.convolutions(x) 112 | y, meta = self.fc_layers(y, tasks=tasks) 113 | return y, meta 114 | 115 | 116 | class RoutedAllFC(PerTask_all_fc): 117 | def __init__(self, decision_maker, in_channels, convnet_out_size, out_dim, num_modules, num_agents): 118 | PerTask_all_fc.__init__(self, in_channels, convnet_out_size, out_dim, num_modules, num_agents) 119 | print('Routing Networks: all fc') 120 | self._initialization = Initialization() 121 | self._per_task_assignment = PerTaskAssignment() 122 | 123 | self._decision_1 = decision_maker( 124 | num_modules, convnet_out_size, num_agents=num_agents, policy_storage_type='tabular', 125 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 126 | self._decision_2 = decision_maker( 127 | num_modules, 48, num_agents=num_agents, policy_storage_type='tabular', 128 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 129 | self._decision_3 = decision_maker( 130 | num_modules, 48, num_agents=num_agents, policy_storage_type='tabular', 131 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 132 | 133 | self._selection_1 = Selection(*[LinearWithRelu(convnet_out_size, 48) for _ in range(num_modules)]) 134 | self._selection_2 = Selection(*[LinearWithRelu(48, 48) for _ in range(num_modules)]) 135 | self._selection_3 = Selection(*[nn.Linear(48, out_dim) for _ in range(num_modules)]) 136 | # self._selection_f = Selection(*[nn.Linear(48, out_dim) for _ in range(num_modules)]) 137 | 138 | def forward(self, x, tasks): 139 | y = self.convolutions(x) 140 | y, meta, actions = self._initialization(y, tasks=tasks) 141 | y, meta, task_actions = self._per_task_assignment(y, meta, actions) 142 | y, meta, routing_actions_1 = self._decision_1(y, meta, task_actions) 143 | y, meta, _ 
= self._selection_1(y, meta, routing_actions_1) 144 | y, meta, routing_actions_2 = self._decision_2(y, meta, task_actions) 145 | y, meta, _ = self._selection_2(y, meta, routing_actions_2) 146 | y, meta, routing_actions_3 = self._decision_3(y, meta, task_actions) 147 | y, meta, _ = self._selection_3(y, meta, routing_actions_3) 148 | # y, meta, _ = self._selection_3(y, meta, task_actions) 149 | # y, meta, _ = self._selection_f(y, meta, routing_actions_3) 150 | return y, meta 151 | 152 | def _get_params_by_class(self, cls): 153 | params = [] 154 | for mod in self.modules(): 155 | if mod is self: 156 | continue 157 | if isinstance(mod, cls): 158 | params += list(mod.parameters()) 159 | return params 160 | 161 | def routing_parameters(self): 162 | return self._get_params_by_class(Decision) 163 | 164 | def module_parameters(self): 165 | params = self._get_params_by_class(Selection) 166 | params += list(self.convolutions.parameters()) 167 | return params 168 | 169 | 170 | class Dispatched(RoutedAllFC): 171 | def __init__(self, decision_maker, in_channels, convnet_out_size, out_dim, num_modules, num_agents): 172 | RoutedAllFC.__init__(self, decision_maker, in_channels, convnet_out_size, out_dim, num_modules, num_agents) 173 | self._per_task_assignment = decision_maker( 174 | num_agents, convnet_out_size, num_agents=1, policy_storage_type='tabular', 175 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 176 | 177 | 178 | class PerDecisionSingleAgent(RoutedAllFC): 179 | def __init__(self, decision_maker, in_channels, convnet_out_size, out_dim, num_modules, num_agents): 180 | RoutedAllFC.__init__(self, decision_maker, in_channels, convnet_out_size, out_dim, num_modules, num_agents) 181 | print('Routing Networks: all fc') 182 | self._initialization = Initialization() 183 | 184 | self._decision_1 = decision_maker( 185 | num_modules, convnet_out_size, num_agents=1, policy_storage_type='approx', 186 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 187 | self._decision_2 = decision_maker( 188 | num_modules, 48, num_agents=1, policy_storage_type='approx', 189 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 190 | self._decision_3 = decision_maker( 191 | num_modules, 48, num_agents=1, policy_storage_type='approx', 192 | additional_reward_func=CollaborationReward(reward_ratio=0.3, num_actions=num_modules)) 193 | 194 | def forward(self, x, tasks): 195 | y = self.convolutions(x) 196 | y, meta, actions = self._initialization(y, tasks=tasks) 197 | y, meta, routing_actions_1 = self._decision_1(y, meta, None) 198 | y, meta, _ = self._selection_1(y, meta, routing_actions_1) 199 | y, meta, routing_actions_2 = self._decision_2(y, meta, None) 200 | y, meta, _ = self._selection_2(y, meta, routing_actions_2) 201 | y, meta, routing_actions_3 = self._decision_3(y, meta, None) 202 | y, meta, _ = self._selection_3(y, meta, routing_actions_3) 203 | # y, meta, _ = self._selection_3(y, meta, task_actions) 204 | # y, meta, _ = self._selection_f(y, meta, routing_actions_3) 205 | return y, meta 206 | -------------------------------------------------------------------------------- /PytorchRouting/Examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cle-ros/RoutingNetworks/0f1fe1221c67a224a02bca6247d3c4488ede0a04/PytorchRouting/Examples/__init__.py -------------------------------------------------------------------------------- 
/PytorchRouting/Examples/run_experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines some simple experiments to illustrate how Pytorch-Routing functions. 3 | """ 4 | import numpy as np 5 | import tqdm 6 | import torch 7 | from PytorchRouting.DecisionLayers import REINFORCE, QLearning, SARSA, ActorCritic, GumbelSoftmax, PerTaskAssignment, \ 8 | WPL, AAC, AdvantageLearning, RELAX, EGreedyREINFORCE, EGreedyAAC 9 | from PytorchRouting.Examples.Models import PerTask_all_fc, RoutedAllFC, PerTask_1_fc, PerDecisionSingleAgent, \ 10 | Dispatched 11 | from PytorchRouting.Examples.Datasets import CIFAR100MTL 12 | 13 | 14 | def compute_batch(model, batch): 15 | samples, labels, tasks = batch 16 | out, meta = model(samples, tasks=tasks) 17 | correct_predictions = (out.max(dim=1)[1].squeeze() == labels.squeeze()).cpu().numpy() 18 | accuracy = correct_predictions.sum() 19 | oh_labels = one_hot(labels, out.size()[-1]) 20 | module_loss, decision_loss = model.loss(out, meta, oh_labels) 21 | return module_loss, decision_loss, accuracy 22 | 23 | def one_hot(indices, width): 24 | indices = indices.squeeze().unsqueeze(1) 25 | oh = torch.zeros(indices.size()[0], width).to(indices.device) 26 | oh.scatter_(1, indices, 1) 27 | return oh 28 | 29 | 30 | def run_experiment(model, dataset, learning_rates, routing_module_learning_rate_ratio): 31 | print('Loaded dataset and constructed model. Starting Training ...') 32 | for epoch in range(50): 33 | optimizers = [] 34 | parameters = [] 35 | if epoch in learning_rates: 36 | try: 37 | optimizers.append(torch.optim.SGD(model.routing_parameters(), 38 | lr=routing_module_learning_rate_ratio*learning_rates[epoch])) 39 | optimizers.append(torch.optim.SGD(model.module_parameters(), 40 | lr=learning_rates[epoch])) 41 | parameters = model.module_parameters() + model.module_parameters() 42 | except AttributeError: 43 | optimizers.append(torch.optim.SGD(model.parameters(), lr=learning_rates[epoch])) 44 | parameters = model.parameters() 45 | train_log, test_log = np.zeros((3,)), np.zeros((3,)) 46 | train_samples_seen, test_samples_seen = 0, 0 47 | dataset.enter_train_mode() 48 | model.train() 49 | # while True: 50 | pbar = tqdm.tqdm(unit=' samples') 51 | while True: 52 | try: 53 | batch = dataset.get_batch() 54 | except StopIteration: 55 | break 56 | train_samples_seen += len(batch[0]) 57 | pbar.update(len(batch[0])) 58 | module_loss, decision_loss, accuracy = compute_batch(model, batch) 59 | (module_loss + decision_loss).backward() 60 | torch.nn.utils.clip_grad_norm_(parameters, 40., norm_type=2) 61 | for opt in optimizers: 62 | opt.step() 63 | model.zero_grad() 64 | train_log += np.array([module_loss.tolist(), decision_loss.tolist(), accuracy]) 65 | pbar.close() 66 | dataset.enter_test_mode() 67 | model.eval() 68 | model.start_logging_selections() 69 | while True: 70 | try: 71 | batch = dataset.get_batch() 72 | except StopIteration: 73 | break 74 | test_samples_seen += len(batch[0]) 75 | module_loss, decision_loss, accuracy = compute_batch(model, batch) 76 | test_log += np.array([module_loss.tolist(), decision_loss.tolist(), accuracy]) 77 | print('Epoch {} finished after {} train and {} test samples..\n' 78 | ' Training averages: Model loss: {}, Routing loss: {}, Accuracy: {}\n' 79 | ' Testing averages: Model loss: {}, Routing loss: {}, Accuracy: {}'.format( 80 | epoch + 1, train_samples_seen, test_samples_seen, 81 | *(train_log/train_samples_seen).round(3), *(test_log/test_samples_seen).round(3))) 82 | 
model.stop_logging_selections_and_report() 83 | 84 | 85 | if __name__ == '__main__': 86 | # MNIST 87 | # dataset = MNIST_MTL(64, data_files=['./Datasets/mnist.pkl.gz']) 88 | # model = PerTask_all_fc(1, 288, 2, dataset.num_tasks, dataset.num_tasks) 89 | # model = WPL_routed_all_fc(1, 288, 2, dataset.num_tasks, dataset.num_tasks) 90 | cuda = False 91 | # cuda = True 92 | 93 | # CIFAR 94 | dataset = CIFAR100MTL(10, data_files=['./Datasets/cifar-100-py/train', './Datasets/cifar-100-py/test'], cuda=cuda) 95 | model = RoutedAllFC(WPL, 3, 128, 5, dataset.num_tasks, dataset.num_tasks) 96 | # model = RoutedAllFC(RELAX, 3, 128, 5, dataset.num_tasks, dataset.num_tasks) 97 | # model = RoutedAllFC(EGreedyREINFORCE, 3, 128, 5, dataset.num_tasks, dataset.num_tasks) 98 | # model = RoutedAllFC(AdvantageLearning, 3, 128, 5, dataset.num_tasks, dataset.num_tasks) 99 | # model = PerDecisionSingleAgent(AdvantageLearning, 3, 128, 5, dataset.num_tasks, dataset.num_tasks) 100 | # model = Dispatched(AdvantageLearning, 3, 128, 5, dataset.num_tasks, dataset.num_tasks) 101 | 102 | learning_rates = {0: 3e-3, 5: 1e-3, 10: 3e-4} 103 | routing_module_learning_rate_ratio = 0.3 104 | if cuda: 105 | model.cuda() 106 | run_experiment(model, dataset, learning_rates, routing_module_learning_rate_ratio) 107 | 108 | ''' 109 | WPL_routed_all_fc(3, 512, 5, dataset.num_tasks, dataset.num_tasks) 110 | Training averages: Model loss: 0.427, Routing loss: 8.864, Accuracy: 0.711 111 | Testing averages: Model loss: 0.459, Routing loss: 9.446, Accuracy: 0.674 112 | ''' 113 | -------------------------------------------------------------------------------- /PytorchRouting/Helpers/MLP.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class MLP. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/7/18 6 | """ 7 | import numpy as np 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class MLP(nn.Module): 13 | def __init__(self, input_dim, output_dim, layers, nonlin=F.relu): 14 | nn.Module.__init__(self) 15 | self._layers = nn.ModuleList() 16 | input_dim = int(np.prod(input_dim)) 17 | last_dim = input_dim 18 | self._nonlin = nonlin 19 | for hidden_layer_dim in layers: 20 | self._layers.append(nn.Linear(last_dim, hidden_layer_dim)) 21 | last_dim = hidden_layer_dim 22 | self._layers.append(nn.Linear(last_dim, output_dim)) 23 | 24 | def forward(self, arg, *args): 25 | out = arg 26 | for i in range(len(self._layers) - 1): 27 | layer = self._layers[i] 28 | evaluated = layer(out) 29 | out = self._nonlin(evaluated) 30 | out = self._layers[-1](out) 31 | return out 32 | 33 | __call__ = forward 34 | -------------------------------------------------------------------------------- /PytorchRouting/Helpers/RLSample.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class RLSample. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | 8 | 9 | class RLSample(object): 10 | """ 11 | RLSample defines a simple struct-like class that is used to combine RL-relevant training information. 12 | (i.e. 
state, action, reward, next state, next action) 13 | """ 14 | 15 | def __init__( 16 | self, 17 | loss_function, 18 | state, 19 | action, 20 | reward, 21 | cum_return, 22 | prior_action, 23 | next_state, 24 | next_action 25 | ): 26 | self.loss_function = loss_function 27 | self.state = state 28 | self.action = action 29 | self.prior_action = prior_action 30 | self.reward = reward 31 | self.cum_return = cum_return 32 | self.next_state = next_state 33 | self.next_action = next_action 34 | -------------------------------------------------------------------------------- /PytorchRouting/Helpers/SampleMetaInformation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class SampleMetaInformation. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/6/18 6 | """ 7 | from collections import defaultdict 8 | 9 | 10 | class SampleMetaInformation(object): 11 | """ 12 | Class SampleMetaInformation should be used to store metainformation for each sample. 13 | """ 14 | 15 | def __init__(self, task=None): 16 | self.task = task 17 | self.steps = [] 18 | 19 | def append(self, attr_name, obj, new_step=False): 20 | if new_step: 21 | self.steps.append({}) 22 | else: 23 | assert len(self.steps) > 0, 'initialize a new step first by calling this function with new_step=True' 24 | self.steps[-1][attr_name] = obj 25 | 26 | def finalize(self): 27 | """ 28 | This method finalizes a trajectory, by translating the stored sar tuples into attributes of this class 29 | :return: 30 | """ 31 | res = {} 32 | for step in self.steps: 33 | for key in step.keys(): 34 | res[key] = [] 35 | for i, step in enumerate(self.steps): 36 | for key in res.keys(): 37 | if key not in step: 38 | res[key].append(None) 39 | else: 40 | res[key].append(step[key]) 41 | for key, val in res.items(): 42 | setattr(self, key, val) 43 | -------------------------------------------------------------------------------- /PytorchRouting/Helpers/TorchHelpers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FakeFlatModuleList(nn.ModuleList): 5 | """ 6 | A little util class inherited from nn.ModuleList that returns the same object for any index requested. 7 | """ 8 | def __getitem__(self, idx): 9 | assert len(self._modules) == 1, 'Fake ModuleList with more than one module instantiated. Aborting.' 
10 | if isinstance(idx, slice): 11 | raise ValueError('cannot slice into a FakeModuleList') 12 | else: 13 | return self._modules['0'] 14 | 15 | 16 | class Identity(nn.Module): 17 | def __init__(self, *args, **kwargs): 18 | nn.Module.__init__(self) 19 | 20 | def forward(self, xs): 21 | return xs 22 | -------------------------------------------------------------------------------- /PytorchRouting/Helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cle-ros/RoutingNetworks/0f1fe1221c67a224a02bca6247d3c4488ede0a04/PytorchRouting/Helpers/__init__.py -------------------------------------------------------------------------------- /PytorchRouting/PreFabs/RNNcells.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from PytorchRouting.CoreLayers import Selection 7 | import PytorchRouting.DecisionLayers as ptrdl 8 | from PytorchRouting.RewardFunctions.PerAction import RunningAverageCollaborationReward 9 | from PytorchRouting.RewardFunctions.PerAction.PerActionBaseReward import PerActionBaseReward 10 | from PytorchRouting.RewardFunctions.PerAction import RunningAverageCollaborationReward as RACollRew 11 | 12 | from PytorchRouting.Helpers.TorchHelpers import FakeFlatModuleList, Identity 13 | 14 | 15 | class RoutingRNNCellBase(nn.Module, metaclass=abc.ABCMeta): 16 | def __init__( 17 | self, 18 | in_features, 19 | hidden_size, 20 | num_selections, 21 | depth_routing, 22 | routing_agent, 23 | route_i2h=False, 24 | route_h2h=True, 25 | recurrent=False, 26 | nonlin=F.relu, 27 | num_agents=1, 28 | exploration=0.1, 29 | policy_storage_type='approx', 30 | detach=True, 31 | approx_hidden_dims=(), 32 | additional_reward_func=PerActionBaseReward(), 33 | **kwargs, 34 | ): 35 | nn.Module.__init__(self) 36 | self.in_features = in_features 37 | self.hidden_size = hidden_size 38 | self._projection_size = 4*hidden_size 39 | self.routing_width = num_selections 40 | self.routing_depth = depth_routing 41 | self.recurrent = recurrent 42 | self.nonlin = nonlin 43 | 44 | assert not (route_i2h and recurrent and in_features != self.hidden_size),\ 45 | 'Cannot route i2h recurrently if hidden_dim != in_features (hidden: {}, in: {})'.\ 46 | format(self.hidden_size, in_features) 47 | assert issubclass(routing_agent, ptrdl.Decision), \ 48 | 'Please pass the routing_agent as a class-object of the appropriate type. 
Reveiced {}'.format(routing_agent) 49 | 50 | # pre-computing the defs of the routed layers 51 | dimensionality_defs_i2h = [in_features] + [self._projection_size] * depth_routing 52 | dimensionality_defs_h2h = [self.hidden_size] + [self._projection_size] * depth_routing 53 | 54 | # instantiating the different routing types 55 | if route_i2h and route_h2h: 56 | # the decision makers 57 | self.router_i2h, self.selection_i2h = self._create_routing( 58 | routing_agent, num_agents, exploration, policy_storage_type, detach, approx_hidden_dims, 59 | additional_reward_func, dimensionality_defs_i2h) 60 | self.router_h2h, self.selection_h2h = self._create_routing( 61 | routing_agent, num_agents, exploration, policy_storage_type, detach, approx_hidden_dims, 62 | additional_reward_func, dimensionality_defs_h2h) 63 | self._route = self._route_i2h_h2h 64 | elif route_i2h: 65 | # the decision makers 66 | self.router_i2h, self.selection_i2h = self._create_routing( 67 | routing_agent, num_agents, exploration, policy_storage_type, detach, approx_hidden_dims, 68 | additional_reward_func, dimensionality_defs_i2h) 69 | self.linear_h2h = nn.Linear(self.hidden_size, self._projection_size) 70 | self._route = self._route_i2h 71 | elif route_h2h: 72 | # the decision makers 73 | self.linear_i2h = nn.Linear(self.in_features, self._projection_size) 74 | self.router_h2h, self.selection_h2h = self._create_routing( 75 | routing_agent, num_agents, exploration, policy_storage_type, detach, approx_hidden_dims, 76 | additional_reward_func, dimensionality_defs_h2h) 77 | self._route = self._route_h2h 78 | else: 79 | raise ValueError('Neither i2h nor h2h routing specified. Please use regular RNNCell instead.') 80 | 81 | self.reset_parameters() 82 | 83 | def _create_routing(self, routing_agent, num_agents, exploration, policy_storage_type, detach, approx_hidden_dims, 84 | additional_reward_func, dimensionality_defs): 85 | list_type = nn.ModuleList if not self.recurrent else FakeFlatModuleList 86 | effective_width = self.routing_width if not self.recurrent else self.routing_width + 1 # for termination action 87 | effective_depth = self.routing_depth if not self.recurrent else 1 88 | base_selection = [] if not self.recurrent else [Identity()] 89 | router = list_type([ 90 | routing_agent( 91 | num_selections=effective_width, 92 | in_features=dimensionality_defs[i], 93 | num_agents=num_agents, 94 | exploration=exploration, 95 | policy_storage_type=policy_storage_type, 96 | detach=detach, 97 | approx_hidden_dims=approx_hidden_dims, 98 | additional_reward_func=additional_reward_func 99 | ) for i in range(effective_depth) 100 | ]) 101 | selection = list_type([ 102 | Selection(*(base_selection + [ # need base selection for termination action 103 | nn.Linear(dimensionality_defs[i], dimensionality_defs[i + 1]) 104 | for _ in range(effective_width) 105 | ])) 106 | for i in range(effective_depth) 107 | ]) 108 | return router, selection 109 | 110 | def reset_parameters(self): 111 | std = 1.0 / math.sqrt(self.hidden_size) 112 | for w in self.parameters(): 113 | w.data.uniform_(-std, std) 114 | 115 | def _init_hidden(self, input_): 116 | h = input_.new_zeros((input_.size(0), self.args.d_hidden)) 117 | c = input_.new_zeros((input_.size(0), self.args.d_hidden)) 118 | # h = torch.zeros_like(input_) 119 | # c = torch.zeros_like(input_) 120 | return h, c 121 | 122 | @abc.abstractmethod 123 | def forward(self, x, hidden, metas, task_actions=None, mask=None): pass 124 | 125 | def _route(self, x, h, c, metas, task_actions, mask): 126 | return 
torch.Tensor(), torch.Tensor(), [] 127 | 128 | def _route_internals(self, input, metas, decisions, selections, task_actions=None, mask=None): 129 | batch_size = len(metas) 130 | mask = torch.ones(batch_size, dtype=torch.uint8, device=input.device) if mask is None else mask 131 | for i in range(self.routing_depth): 132 | if not any(mask): 133 | break 134 | input, metas, actions = decisions[i](input, metas, prior_actions=task_actions, mask=mask) 135 | mask *= (1 - (actions.squeeze() == 0)) 136 | input, metas, _ = selections[i](input, metas, actions, mask=mask) 137 | if i < (self.routing_depth - 1): 138 | input = self.nonlin(input) 139 | return input, metas 140 | 141 | def _route_i2h(self, x, h, c, metas, task_actions, mask): 142 | i2h, metas = self._route_internals(x, metas, self.router_i2h, self.selection_i2h, task_actions, mask) 143 | h2h = self.linear_h2h(h) 144 | return i2h, h2h, metas 145 | 146 | def _route_h2h(self, x, h, c, metas, task_actions, mask): 147 | i2h = self.linear_i2h(x) 148 | h2h, metas = self._route_internals(h, metas, self.router_h2h, self.selection_h2h, task_actions, mask) 149 | return i2h, h2h, metas 150 | 151 | def _route_i2h_h2h(self, x, h, c, metas, task_actions, mask): 152 | i2h, metas = self._route_internals(x, metas, self.router_i2h, self.selection_i2h, task_actions, mask) 153 | h2h, metas = self._route_internals(h, metas, self.router_h2h, self.selection_h2h, task_actions, mask) 154 | return i2h, h2h, metas 155 | 156 | 157 | class RoutingLSTMCell(RoutingRNNCellBase): 158 | def forward(self, x, hidden, metas, task_actions=None, mask=None): 159 | if hidden is None: 160 | hidden = self._init_hidden(x) 161 | h, c = hidden 162 | 163 | # Linear mappings 164 | i2h_x, h2h_x, metas = self._route(x, h, c, metas, task_actions, mask) 165 | preact = i2h_x + h2h_x 166 | 167 | # activations 168 | gates = preact[:, :3 * self.hidden_size].sigmoid() 169 | c_hat = preact[:, 3 * self.hidden_size:].tanh() # input gating 170 | i_t = gates[:, :self.hidden_size] # input 171 | f_t = gates[:, self.hidden_size:2 * self.hidden_size] # forgetting 172 | o_t = gates[:, -self.hidden_size:] # output 173 | 174 | c_t = torch.mul(c, f_t) + torch.mul(i_t, c_hat) 175 | h_t = torch.mul(o_t, c_t.tanh()) 176 | return (h_t, c_t), metas 177 | 178 | 179 | class RoutingGRUCell(RoutingRNNCellBase): 180 | def forward(self, x, hidden, metas): 181 | raise NotImplementedError() 182 | -------------------------------------------------------------------------------- /PytorchRouting/PreFabs/__init__.py: -------------------------------------------------------------------------------- 1 | from .RNNcells import RoutingLSTMCell 2 | -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/Final/BaseReward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class BaseReward. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | import abc 8 | import torch.nn as nn 9 | 10 | 11 | class BaseReward(nn.Module, metaclass=abc.ABCMeta): 12 | """ 13 | Class BaseReward defines the base function for all final reward functions. 
14 | """ 15 | 16 | def __init__(self, scale=1.): 17 | nn.Module.__init__(self) 18 | self._scale = scale 19 | 20 | @abc.abstractmethod 21 | def forward(self, loss, yest, ytrue): pass 22 | -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/Final/CorrectClassifiedReward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class CorrectClassifiedReward. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | from .BaseReward import BaseReward 8 | 9 | 10 | class CorrectClassifiedReward(BaseReward): 11 | """ 12 | Class CorrectClassifiedReward defines the +1 reward for correct classification, and -1 otherwise. 13 | """ 14 | 15 | def __init__(self, *args, **kwargs): 16 | BaseReward.__init__(self, *args, **kwargs) 17 | 18 | def forward(self, loss, yest, ytrue): 19 | # input checking - onehot vs indices 20 | if yest.numel() == yest.size(0): 21 | y_ind = yest 22 | else: 23 | _, y_ind = yest.max(dim=1) 24 | if ytrue.numel() == ytrue.size(0): 25 | yt_ind = ytrue 26 | else: 27 | _, yt_ind = ytrue.max(dim=1) 28 | return -1. + 2. * (y_ind.squeeze() == yt_ind.squeeze()).float() -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/Final/NegLossReward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class NegLossReward. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | import torch 8 | from .BaseReward import BaseReward 9 | 10 | 11 | class NegLossReward(BaseReward): 12 | """ 13 | Class NegLossReward defines the simplest reward function, expressed as the negative loss. 14 | """ 15 | 16 | def __init__(self, *args, **kwargs): 17 | BaseReward.__init__(self, *args, **kwargs) 18 | 19 | def forward(self, loss, yest, ytrue): 20 | with torch.no_grad(): 21 | reward = -loss.squeeze() 22 | return reward -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/Final/__init__.py: -------------------------------------------------------------------------------- 1 | from .NegLossReward import NegLossReward 2 | from .CorrectClassifiedReward import CorrectClassifiedReward -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/PerAction/CollaborationReward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class CollaborationReward. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | import torch 8 | 9 | from .PerActionBaseReward import PerActionBaseReward 10 | 11 | 12 | class CollaborationReward(PerActionBaseReward): 13 | """ 14 | Class CollaborationReward defines a collaboration reward measured by the average probability 15 | of taking the action taken by an agent. 16 | """ 17 | 18 | def __init__(self, reward_ratio=0.1, num_actions=None, history_len=256): 19 | PerActionBaseReward.__init__(self, history_len) 20 | self._reward_ratio = reward_ratio 21 | self._num_actions = num_actions 22 | 23 | def get_reward(self, dist, action): 24 | action_count = torch.zeros(len(self._actions), self._num_actions).to(dist.device) 25 | action_count = action_count.scatter(1, torch.stack(list(self._actions), 0).unsqueeze(1), 1.) 
26 | action_count = torch.sum(action_count, dim=0)/len(self._actions) 27 | self._precomp = action_count 28 | self._precomp = self._reward_ratio * self._precomp 29 | return self._precomp[action] * self._reward_ratio 30 | 31 | 32 | class RunningAverageCollaborationReward(PerActionBaseReward): 33 | """ 34 | Provides the same functionality as CollaborationReward, but with a much faster computing running average. 35 | """ 36 | 37 | def __init__(self, reward_ratio=0.1, num_actions=None, history_len=256): 38 | PerActionBaseReward.__init__(self, history_len) 39 | self._reward_ratio = reward_ratio 40 | self._num_actions = num_actions 41 | self._adaptation_rate = 10 ** (-1./history_len) 42 | self._dists = None 43 | self._actions = None 44 | self._precomp = None 45 | 46 | def register(self, dist, action): 47 | # initializing 48 | if self._actions is None: 49 | self._actions = torch.zeros(self._num_actions).to(dist.device) 50 | # one hot encoding 51 | action_oh = torch.zeros_like(self._actions).float() 52 | action_oh[action.item()] = 1. 53 | # running average learning 54 | self._actions = self._adaptation_rate * self._actions + (1 - self._adaptation_rate) * action_oh 55 | # normalizing 56 | self._actions = self._actions / self._actions.sum() 57 | 58 | def clear(self): 59 | self._dists = None 60 | self._actions = None 61 | self._precomp = None 62 | 63 | def get_reward(self, dist, action): 64 | return self._actions[action] * self._reward_ratio 65 | -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/PerAction/ManualReward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class ManualReward. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 12/10/18 6 | """ 7 | import torch 8 | from .PerActionBaseReward import PerActionBaseReward 9 | 10 | 11 | class ManualReward(PerActionBaseReward): 12 | """ 13 | Class ManualReward defines ... 14 | """ 15 | 16 | def __init__(self, rewards, num_actions=None): 17 | PerActionBaseReward.__init__(self) 18 | if num_actions is not None: 19 | assert len(rewards) == num_actions 20 | self._rewards = torch.FloatTensor(rewards).squeeze() 21 | self._num_actions = num_actions 22 | 23 | def get_reward(self, dist, action): 24 | return self._rewards[action].to(action.device) 25 | -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/PerAction/PerActionBaseReward.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class BaseReward. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/8/18 6 | """ 7 | from collections import deque 8 | import abc 9 | import torch 10 | 11 | 12 | class PerActionBaseReward(object, metaclass=abc.ABCMeta): 13 | """ 14 | Class BaseReward defines the base class for per-action rewards. 
15 | """ 16 | 17 | def __init__(self, history_window=256, *args, **kwargs): 18 | self._hist_len = history_window 19 | self._dists = deque(maxlen=history_window) 20 | self._actions = deque(maxlen=history_window) 21 | self._precomp = None 22 | 23 | def register(self, dist, action): 24 | self._dists.append(dist.detach()) 25 | self._actions.append(action.detach()) 26 | 27 | def clear(self): 28 | self._dists = deque(maxlen=self._hist_len) 29 | self._actions = deque(maxlen=self._hist_len) 30 | self._precomp = None 31 | 32 | def get_reward(self, dist, action): 33 | return torch.FloatTensor([0.]).to(action.device) 34 | -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/PerAction/__init__.py: -------------------------------------------------------------------------------- 1 | from .CollaborationReward import CollaborationReward, RunningAverageCollaborationReward 2 | from .ManualReward import ManualReward 3 | -------------------------------------------------------------------------------- /PytorchRouting/RewardFunctions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cle-ros/RoutingNetworks/0f1fe1221c67a224a02bca6247d3c4488ede0a04/PytorchRouting/RewardFunctions/__init__.py -------------------------------------------------------------------------------- /PytorchRouting/UtilLayers/Sequential.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file defines class RoutingSequential. 3 | 4 | @author: Clemens Rosenbaum :: cgbr@cs.umass.edu 5 | @created: 6/13/18 6 | """ 7 | from collections import OrderedDict 8 | import torch.nn as nn 9 | 10 | from PytorchRouting.CoreLayers.Initialization import Initialization 11 | from PytorchRouting.CoreLayers.Selection import Selection 12 | from PytorchRouting.DecisionLayers.Decision import Decision 13 | 14 | 15 | class Sequential(nn.Sequential): 16 | """ 17 | Sequential is a routing wrapper around the original torch.nn.Sequential class. 18 | It includes the "Initialization" layer and handles the routing triplet (y, meta, actions) sequentially. 19 | As a consequence, it cannot handle cases where actions are not immediately consumed, but have to be 20 | handled repeatedly (dispatching). 21 | """ 22 | 23 | def __init__(self, *args): 24 | additional_modules = OrderedDict([('initialization', Initialization())]) 25 | if isinstance(args, OrderedDict): 26 | args = additional_modules.update(args) 27 | else: 28 | args = list(additional_modules.values()) + list(args) 29 | nn.Sequential.__init__(self, *args) 30 | 31 | def forward(self, x, tasks=()): 32 | """ 33 | As the class Sequential includes an initialization layer, forward only takes a batch of input, and a list of 34 | tasks. 35 | :param x: samples. 
the first dim has to be the batch dimension 36 | :param tasks: a list/tuple/iterable of integer task labels 37 | :return: 38 | """ 39 | initialization_module = self._modules[list(self._modules.keys())[0]] 40 | ys, meta, actions = initialization_module(x, tasks=tasks) 41 | for name, mod in list(self._modules.items())[1:]: 42 | if isinstance(mod, Selection) or isinstance(mod, Decision): 43 | ys, meta, actions = mod(ys, meta, actions) 44 | elif isinstance(mod, nn.Module): 45 | ys = mod(ys) 46 | else: 47 | raise ValueError('Sequential can only be initialized with nn.Modules.') 48 | return ys, meta 49 | -------------------------------------------------------------------------------- /PytorchRouting/UtilLayers/__init__.py: -------------------------------------------------------------------------------- 1 | from .Sequential import Sequential -------------------------------------------------------------------------------- /PytorchRouting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cle-ros/RoutingNetworks/0f1fe1221c67a224a02bca6247d3c4488ede0a04/PytorchRouting/__init__.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-Routing 2 | Pytorch-Routing is a pytorch-based implementation of 'RoutingNetworks' for Python 3.5+. The best overview over the work can probably be found in: 3 | 4 | Clemens Rosenbaum, Ignacio Cases, Matthew Riemer, Tim Klinger - _Routing Networks and the Challenges of Modular and Compositional Computation_ (arxiv) 5 | 6 | https://arxiv.org/abs/1904.12774 7 | 8 | The idea was originally published in ICLR: 9 | 10 | Clemens Rosenbaum, Tim Klinger, Matthew Riemer - _Routing Networks: Adaptive Selection of Non-Linear Functions for Multi-Task Learning_ (ICLR 2018). 11 | 12 | https://openreview.net/forum?id=ry8dvM-R- 13 | 14 | An extension to language domains was introduced in NAACL: 15 | 16 | Ignacio Cases, Clemens Rosenbaum, Matthew Riemer, Atticus Geiger, Tim Klinger, Alex Tamkin, Olivia Li, Sandhini Agarwal, Joshua D. Greene, Dan Jurafsky, Christopher Potts and Lauri Karttunen "Recursive Routing Networks: Learning to Compose Modules for Language Understanding" (NAACL 2019). 17 | 18 | https://www.aclweb.org/anthology/N19-1365 19 | 20 | The latest research on "dispatched" routing networks for single task learning can be found here: 21 | 22 | Clemens Rosenbaum, Ignacio Cases, Matthew Riemer, Atticus Geiger, Lauri Karttunen, Joshua D. Greene, Dan Jurafsky, Christopher Potts "Dispatched Routing Networks" (Stanford Tech Report 2019). 23 | 24 | https://nlp.stanford.edu/projects/sci/dispatcher.pdf 25 | 26 | ### What's new 27 | I added implementations of several different new decision making algorithms. In particular, I added reparameterization techniques such as Gumbel/Concrete and RELAX. Additionally, I added some Advantage based RL techniques. 28 | 29 | I also added a new module called "prefabs" that includes already defined more or less standard routed layers. For now, it only contains an RNN prefab in form of a routed LSTM. Routing for both i2h and h2h layers can be specified at initialization. 30 | 31 | ## Implementation 32 | This package provides an implementation of RoutingNetworks that tries to integrate with Pytorch (https://pytorch.org/) as smoothly as possible by providing RoutingNetwork "layers", each implemented as a `nn.Module`. 
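At a glance, these layers chain together much like regular Pytorch modules. The snippet below is a minimal sketch of a per-task routed classifier in the style of `PytorchRouting.Examples.Models`, using the `Sequential` wrapper (which handles the initialization internally); the module count, dimensions and variable names (`x`, `tasks`, `y_true`) are illustrative placeholders, not part of the library's API.
```Python
import torch.nn as nn

from PytorchRouting.UtilLayers import Sequential
from PytorchRouting.CoreLayers import Selection, Loss
from PytorchRouting.DecisionLayers import PerTaskAssignment
from PytorchRouting.RewardFunctions.Final import NegLossReward

num_modules, in_dim, out_dim = 4, 128, 10  # illustrative placeholder sizes

# one routed fully-connected layer: each sample is dispatched to the module
# selected by its task label
routed = Sequential(
    PerTaskAssignment(),
    Selection(*[nn.Linear(in_dim, out_dim) for _ in range(num_modules)]),
)
loss_func = Loss(nn.CrossEntropyLoss(), NegLossReward())

# forward pass and loss; x, tasks and y_true are assumed to be given:
# y, meta = routed(x, tasks=tasks)
# module_loss, routing_loss = loss_func(y, y_true, meta)
# (module_loss + routing_loss).backward()
```
The individual layers used here are explained one by one in the sections below.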
33 | 34 | (To skip the explanations and go straight to the examples, see [here](EXAMPLES.md).) 35 | 36 | The basic functionality of routing is provided by four different kinds of modules: 37 | 38 | ### PytorchRouting.CoreLayers.Initialization.Initialization 39 | As Routing Networks need to track 'meta-information', we need to work around some Pytorch restrictions by extending what a layer takes as an argument and what it returns. This meta-information consists of the trajectories necessary to later train Reinforcement Learning-based routers, and of the actions used for decisions. Consequently, the information passed from one Pytorch-Routing layer to the next is a triplet of the form `(batch, meta_info_list, actions)`. 40 | 41 | The initialization of the meta-information objects - one for each sample in a batch - is thus the first required step when using this package, and is achieved with the `Initialization` module. 42 | ```Python 43 | init = Initialization() 44 | batch, meta_list, actions = init(batch, tasks=()) 45 | ``` 46 | The initialization module takes the batch - in the form of a Pytorch `Variable` (with the first dim as the batch dim) - and an optional list of task labels (for multi-task learning), and returns the required triplet `(batch, meta_info_list, actions)` (though with empty actions). 47 | 48 | ### PytorchRouting.DecisionLayers.* 49 | The next step in routing a network is to make a routing decision (i.e. to create a selection) for each sample. These layers - with one class for each decision-making technique - take the Pytorch-Routing triplet and make a decision for each sample in the batch. These decisions are logged in the meta-information objects, and returned as a `torch.LongTensor` as the third element of the Pytorch-Routing triplet: 50 | 51 | ```Python 52 | decision = Decision( 53 | num_selections, 54 | in_features, 55 | num_agents=1, 56 | exploration=0.1, 57 | policy_storage_type='approx', 58 | detach=True, 59 | approx_hidden_dims=(), 60 | approx_module=None, 61 | additional_reward_class=PerActionBaseReward, 62 | additional_reward_args={}) 63 | batch, meta_list, new_actions = decision(batch, meta_list, actions) 64 | ``` 65 | The constructor arguments are as follows: `num_selections` defines the number of selections available in the next routed layer; `in_features` defines the dimensionality of one sample when passed into this layer (required to construct function approximators for policies); `num_agents` defines the number of agents available at this layer; `exploration` defines the exploration rate for agents that support it; `policy_storage_type` refers to how the agents' policies are stored, and can be either `approx` or `tabular`; `detach` is a bool and refers to whether or not the gradient flow is cut when passed into the agents' approximators; `approx_hidden_dims` defines the hidden layer dimensions if the agents construct their default policy approximator, an MLP; `approx_module` overrides all other approximator settings, and takes an already instantiated policy approximation module for its agents (which are not limited to MLPs); `additional_reward_class` (together with `additional_reward_args`) specifies a per-action reward of type `PytorchRouting.RewardFunctions.PerAction.*` that determines how per-action rewards are calculated by the agents. As this reward design may vary per layer, it has to be located here, and not in the final loss function as the other rewards are (see below).
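As a concrete illustration, a Q-Learning-based decision layer that chooses between five modules from 128-dimensional sample representations could be constructed roughly as follows. This is a sketch that assumes the concrete decision classes (here `QLearning`) expose the generic constructor above; all numbers are placeholders.
```Python
from PytorchRouting.DecisionLayers.ReinforcementLearning.QLearning import QLearning

decision = QLearning(5, 128,                    # 5 selectable modules, 128-dimensional inputs
                     num_agents=1,
                     exploration=0.1,
                     policy_storage_type='approx',
                     approx_hidden_dims=(64,))  # one hidden layer for the default MLP policy approximator
batch, meta_list, new_actions = decision(batch, meta_list, actions)
```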
66 | 67 | #### _Dispatching_ 68 | The `actions` argument to the layer call will be interpreted as the dispatcher actions specifying the agents to be selected: 69 | ```Python 70 | # 1. getting the dispatcher actions 71 | batch, meta_list, dispatcher_actions = decision_dispatcher(batch, meta_list, []) 72 | # 2. passing the dispatcher actions to an agent 73 | batch, meta_list, selection_actions = decision_selector(batch, meta_list, dispatcher_actions) 74 | # 3. selecting the modules (see below) 75 | ``` 76 | Using a special decision module, the same mechanism can also implement per-task agents: 77 | ```Python 78 | # 1. getting the per-task assignment actions 79 | batch, meta_list, per_task_actions = PytorchRouting.DecisionLayers.Others.PerTaskAssignment()(batch, meta_list, []) 80 | # 2. passing the task assignment preselections 81 | batch, meta_list, selection_actions = decision_selector(batch, meta_list, per_task_actions) 82 | # 3. selecting the modules (see below) 83 | ``` 84 | ### PytorchRouting.CoreLayers.Selection 85 | Now that the actions have been computed, the actual selection of the function block is the next step. This functionality is provided by the `Selection` module: 86 | ```Python 87 | selection = Selection(*modules) 88 | batch_out, meta_list, actions = selection(batch, meta_list, actions) 89 | ``` 90 | Once the module has been initialized by passing in a list of instantiated modules, its application is straightforward. Initializing the selection layer can, for example, look as follows: 91 | ```Python 92 | # for 5 different fully connected layers with the same number of parameters 93 | selection = Selection(*[nn.Linear(in_dim, out_dim) for _ in range(5)]) 94 | # for 2 different MLPs with different numbers of parameters 95 | selection = Selection(MLP(in_dim, out_dim, hidden=(64, 128)), MLP(in_dim, out_dim, hidden=(64, 64))) 96 | ``` 97 | 98 | ### PytorchRouting.CoreLayers.Loss 99 | The final component is a Pytorch-Routing-specific loss module. This is required as the loss from normal training needs to be translated (per sample) into a Reinforcement Learning reward signal: 100 | ```Python 101 | loss_func = Loss(pytorch_loss_func, routing_reward_func) 102 | module_loss, routing_loss = loss_func(batch_estimates, batch_true, meta_list) 103 | ``` 104 | The loss module is instantiated by passing in two other modules - a Pytorch loss function (i.e. a `nn.*Loss*` module) and a reward function (from `PytorchRouting.RewardFunctions.Final.*`) that translates the loss into a reward. Once instantiated, it takes different arguments than the other "layer-like" modules of Pytorch-Routing: the batch estimates (the first output of the routing triplet), the true targets, and the meta-list (the second output of the routing triplet). An example could be: 105 | ```Python 106 | loss_func = Loss(torch.nn.CrossEntropyLoss(), NegLossReward()) 107 | module_loss, routing_loss = loss_func(batch_estimates, batch_true, meta_list) 108 | ``` 109 | To train, we can then simply backpropagate through the combined loss and take an optimization step: 110 | ```Python 111 | total_loss = module_loss + routing_loss 112 | total_loss.backward() 113 | opt.step() 114 | ``` 115 | Additionally, the code allows different learning rates for different components - such as the decision-making networks - using pure Pytorch logic: 116 | ```Python 117 | opt_decision = optim.SGD(decision_module.parameters(), lr=decision_learning_rate) 118 | opt_module = optim.SGD([... all other parameters ...], lr=module_learning_rate) 119 | ```
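A complete update step then simply combines the two losses and steps both optimizers (a sketch using only the names defined above):
```Python
opt_decision.zero_grad()
opt_module.zero_grad()
module_loss, routing_loss = loss_func(batch_estimates, batch_true, meta_list)
(module_loss + routing_loss).backward()
opt_decision.step()
opt_module.step()
```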
120 | ## Examples 121 | See [here](EXAMPLES.md). 122 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='PytorchRouting', 5 | version='0.4.1', 6 | python_requires='>=3.6', 7 | packages=['PytorchRouting', 'PytorchRouting.Helpers', 'PytorchRouting.Examples', 'PytorchRouting.CoreLayers', 8 | 'PytorchRouting.UtilLayers', 'PytorchRouting.DecisionLayers', 'PytorchRouting.DecisionLayers.Others', 9 | 'PytorchRouting.DecisionLayers.ReinforcementLearning', 'PytorchRouting.RewardFunctions', 10 | 'PytorchRouting.RewardFunctions.Final', 'PytorchRouting.RewardFunctions.PerAction'], 11 | url='https://github.com/cle-ros/RoutingNetworks', 12 | install_requires=['torch>=1.0', 'numpy>=1.12'], 13 | license='Apache', 14 | author='Clemens Rosenbaum', 15 | author_email='cgbr@cs.umass.edu', 16 | description='a pytorch-based implementation of "RoutingNetworks"' 17 | ) 18 | --------------------------------------------------------------------------------