├── src
│   ├── __init__.py
│   └── polytropon
│       ├── __init__.py
│       ├── utils.py
│       ├── polytropon.py
│       └── adapters.py
├── media
│   └── paper.pdf
├── setup.py
├── LICENSE
├── README.md
└── .gitignore

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/media/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/McGill-NLP/polytropon/HEAD/media/paper.pdf
--------------------------------------------------------------------------------
/src/polytropon/__init__.py:
--------------------------------------------------------------------------------
1 | from . import adapters, utils
2 | from .polytropon import VARIANT2CLASS, SkilledMixin
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import pathlib
3 | 
4 | here = pathlib.Path(__file__).parent.resolve()
5 | 
6 | # Get the long description from the README file
7 | long_description = (here / 'README.md').read_text(encoding='utf-8')
8 | 
9 | setup(
10 |     name='polytropon',
11 |     version='0.0.1',
12 |     description='Modular Transformers for multitask learning',
13 |     license="MIT",
14 |     long_description=long_description,
15 |     url='https://github.com/McGill-NLP/polytropon',
16 |     author='Edoardo Maria Ponti',
17 |     author_email='edoardo-maria.ponti@mila.quebec',
18 |     package_dir={'': 'src'},
19 |     packages=find_packages(where='src'),
20 | 
21 |     install_requires=[
22 |         "scipy",
23 |         "torch",
24 |         "transformers",
25 |     ],
26 |     python_requires='>=3.7',
27 | )
28 | 
--------------------------------------------------------------------------------
/src/polytropon/utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | 
4 | ATTENTION_LINEARS = ["k_proj", "v_proj", "q_proj", "out_proj", "k", "v", "q", "o"]
5 | 
6 | 
7 | def replace_layers(model, adapter_class, n_tasks, n_skills, skills, only_attention=True):
8 |     for name, module in model.named_children():
9 |         if len(list(module.children())) > 0:
10 |             replace_layers(module, adapter_class, n_tasks, n_skills, skills, only_attention=only_attention)
11 | 
12 |         if isinstance(module, nn.Linear) and (name in ATTENTION_LINEARS or not only_attention):
13 |             new_linear = adapter_class(n_tasks, n_skills, skills, module.weight, module.bias)
14 |             setattr(model, name, new_linear)
15 | 
16 | 
17 | def inform_layers(model, adapter_class, value):
18 |     for module in model.children():
19 |         if len(list(module.children())) > 0:
20 |             inform_layers(module, adapter_class, value)
21 | 
22 |         if isinstance(module, adapter_class):
23 |             module.task_ids = value
24 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 McGill-NLP
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Polytropon: Combining Modular Skills in Multitask Learning
2 | 
3 | [**Updates**](#updates) | [**Installation**](#installation) | [**Usage**](#usage) | [**Cite**](#cite) | [**Paper**](media/paper.pdf)
4 | 
5 | ## Updates
6 | 
7 | 23/05/03: *New repository!* The current repository is outdated; we recommend using the full implementation of Polytropon at https://github.com/microsoft/mttl instead.
8 | 
9 | ## Installation
10 | 
11 | ```bash
12 | pip install git+https://github.com/McGill-NLP/polytropon
13 | ```
14 | 
15 | Alternatively, if you wish to clone the repository:
16 | 
17 | ```bash
18 | git clone https://github.com/McGill-NLP/polytropon.git
19 | cd polytropon
20 | pip install -e .
21 | ```
22 | 
23 | ## Usage
24 | 
25 | ```python
26 | from polytropon import SkilledMixin
27 | 
28 | # load any pretrained model from transformers
29 | from transformers import T5ForConditionalGeneration
30 | model = T5ForConditionalGeneration.from_pretrained("t5-small")
31 | 
32 | # merge it with polytropon
33 | model = SkilledMixin(
34 |     model,
35 |     n_tasks=2,   # number of tasks in the multitask mixture
36 |     n_skills=4,  # size of the inventory of modular skills shared across tasks
37 | )
38 | ```
39 | 
40 | ## Cite
41 | 
42 | ```
43 | @misc{ponti2022combining,
44 |     title={Combining Modular Skills in Multitask Learning},
45 |     author={Edoardo M. Ponti and Alessandro Sordoni and Yoshua Bengio and Siva Reddy},
46 |     year={2022},
47 |     eprint={2202.13914},
48 |     archivePrefix={arXiv},
49 |     primaryClass={cs.LG}
50 | }
51 | ```
52 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /src/polytropon/polytropon.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from scipy import special 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from .adapters import ( 9 | HyperLoRALinear, 10 | SkilledLoRALinear, 11 | SkilledLTSFTLinear, 12 | ) 13 | from .utils import replace_layers, inform_layers 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | VARIANT2CLASS = { 19 | "hyperformer": (HyperLoRALinear, True), 20 | "sparse": (SkilledLTSFTLinear, False), 21 | } 22 | 23 | 24 | class SkilledMixin(nn.Module): 25 | def __init__( 26 | self, 27 | model: nn.Module, 28 | n_tasks: int, 29 | n_skills: int, 30 | skilled_variant: str = "learned", 31 | freeze: bool = True, 32 | custom_skills: str = None, 33 | state_dict = None, 34 | ): 35 | super().__init__() 36 | self.model = model 37 | self.n_tasks = n_tasks 38 | self.n_skills = n_skills 39 | self.skilled_variant = skilled_variant 40 | 41 | if freeze: 42 | for p in self.model.parameters(): 43 | p.requires_grad = False 44 | 45 | adapter_class, only_attention = VARIANT2CLASS.get(skilled_variant, (SkilledLoRALinear, True)) 46 | self.adapter_class = adapter_class 47 | skills = self.get_skills(custom_skills) 48 | replace_layers(self.model, adapter_class, n_tasks, n_skills, skills, only_attention=only_attention) 49 | 50 | if state_dict is not None: 51 | self.model.load_state_dict(state_dict, strict=False) 52 | self.model.tie_weights() 53 | 54 | def get_skills(self, custom_skills): 55 | if self.skilled_variant in ["learned", "hyper", "sparse"]: 56 | # skills are computed inside each module 57 | 
            skills = None
58 |         elif self.skilled_variant == "shared":
59 |             skills = torch.ones((self.n_tasks, 1))  # registered as a buffer in each adapter, so it follows the model's device
60 |         elif self.skilled_variant == "private":
61 |             skills = torch.eye(self.n_tasks, self.n_tasks)
62 |         elif self.skilled_variant == "custom":
63 |             skills = custom_skills
64 |         else:
65 |             raise ValueError
66 | 
67 |         return skills
68 | 
69 |     def generate(self, task_ids, *args, **kwargs):
70 |         inform_layers(self.model, self.adapter_class, task_ids)
71 |         return self.model.generate(*args, **kwargs)
72 | 
73 |     def forward(self, task_ids, *args, add_prior=False, **kwargs):
74 |         inform_layers(self.model, self.adapter_class, task_ids)
75 |         outputs = self.model.forward(*args, **kwargs)
76 | 
77 |         if self.training and self.skilled_variant == "learned" and add_prior:
78 |             aux_loss = [self.neg_log_IBP(p) for n, p in self.model.named_parameters() if "skill_logits" in n]
79 |             outputs.loss += torch.stack(aux_loss).sum()
80 | 
81 |         return outputs
82 | 
83 |     @staticmethod
84 |     def log_factorial(value):
85 |         return torch.lgamma(value + 1)
86 | 
87 |     def neg_log_IBP(self, matrix):
88 |         """ Calculate IBP prior contribution - log P(Z)
89 |         Based on https://github.com/davidandrzej/PyIBP/blob/master/PyIBP.py """
90 | 
91 |         # discretise
92 |         N, K = matrix.shape
93 |         matrix = torch.sigmoid(matrix)
94 |         matrix_hard = (matrix > .5).float()
95 |         Z = matrix_hard - matrix.detach() + matrix
96 | 
97 |         # penalise non-unique histories (columns of Z)
98 |         _, Khs = Z.unique(dim=1, return_counts=True)
99 |         logp = - self.log_factorial(Khs).sum()
100 | 
101 |         # total feature usage
102 |         m = Z.sum(dim=0)
103 |         m = m[m.nonzero()].squeeze()
104 |         logp += (self.log_factorial(N - m) + self.log_factorial(m - 1)).sum()
105 | 
106 |         return - logp
107 | 
108 | 
109 | if __name__ == "__main__":
110 |     from transformers import T5Tokenizer, T5ForConditionalGeneration
111 |     tokenizer = T5Tokenizer.from_pretrained("t5-small")
112 |     inputs = ["Tell me, oh Muse, of that ingenious hero who travelled far and wide after he had sacked the famous town of Troy.",
113 |               "Many cities did he visit, and many were the nations with whose manners and customs he was acquainted."]
114 |     inputs = tokenizer(inputs, return_tensors="pt", padding=True)
115 |     task_ids = torch.LongTensor([0, 1])
116 | 
117 |     for skilled_variant in ["learned", "hyper", "sparse", "shared", "private"]:
118 |         # Reload a fresh backbone each time: replace_layers only swaps nn.Linear modules,
119 |         # so reusing an already-adapted model would keep the previous variant's adapters.
120 |         model = T5ForConditionalGeneration.from_pretrained("t5-small")
121 |         # "shared" expects a single skill; "private" expects n_skills == n_tasks.
122 |         n_skills = 1 if skilled_variant == "shared" else 2
123 |         skilled_model = SkilledMixin(model, n_tasks=2, n_skills=n_skills, skilled_variant=skilled_variant)
124 |         logger.warning("forward %s: %s", skilled_variant, skilled_model.forward(task_ids, labels=inputs["input_ids"], add_prior=True, **inputs))
125 |         logger.warning("generate %s: %s", skilled_variant, skilled_model.generate(task_ids, **inputs))
126 | 
--------------------------------------------------------------------------------
/src/polytropon/adapters.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import itertools
4 | from typing import Optional
5 | 
6 | import torch
7 | from torch import Tensor
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from torch.nn.init import calculate_gain
11 | from torch.distributions.relaxed_bernoulli import RelaxedBernoulli
12 | 
13 | 
14 | EPS = 1e-12
15 | 
16 | 
17 | class SkilledModule(nn.Module):
18 |     def __init__(self):
19 |         super().__init__()
20 |         self._task_ids = None
21 | 
22 |     @property
23 |     def task_ids(self):
24 |         return self._task_ids
25 | 
26 |     @task_ids.setter
27 | def task_ids(self, value): 28 | self._task_ids = value 29 | 30 | 31 | class HyperLoRALinear(SkilledModule): 32 | """ Applies a linear function parameterised by a base bias 33 | and a weighted average of base and task-conditioned weights 34 | """ 35 | __constants__ = ['in_features', 'out_features'] 36 | in_features: int 37 | out_features: int 38 | weight: Tensor 39 | 40 | def __init__(self, 41 | n_tasks: int, 42 | n_skills: int, 43 | skills: Optional[Tensor], 44 | weight: Tensor, 45 | bias: Optional[Tensor], 46 | r: int = 16, 47 | freeze: bool = True 48 | ) -> None: 49 | super().__init__() 50 | self.out_features, self.in_features = weight.shape 51 | self.r = r 52 | 53 | self.task_embs = nn.Embedding(n_tasks, n_skills) 54 | self.task_proj = nn.Sequential( 55 | nn.Linear(n_skills, n_skills), 56 | nn.ReLU(), 57 | nn.Linear(n_skills, n_skills), 58 | ) 59 | 60 | self.weight = nn.Parameter(weight.data) 61 | self.weight.requires_grad = not freeze 62 | 63 | self.hyper_weight_A = nn.Linear(n_skills, r * self.in_features, bias=False) 64 | self.hyper_weight_B = nn.Linear(n_skills, self.out_features * r, bias=False) 65 | self.scaling = 1 / self.r 66 | 67 | if bias is not None: 68 | self.bias = nn.Parameter(bias.data) 69 | self.bias.requires_grad = not freeze 70 | else: 71 | self.register_parameter('bias', None) 72 | 73 | self.reset_parameters() 74 | 75 | def reset_parameters(self): 76 | torch.nn.init.zeros_(self.hyper_weight_B.weight) 77 | 78 | def forward(self, input: Tensor) -> Tensor: 79 | # Provisions for inputs repeated for generation 80 | assert input.size()[0] % self.task_ids.size(0) == 0 81 | repeats = input.size()[0] // self.task_ids.size(0) 82 | if repeats > 1: 83 | self.task_ids = torch.repeat_interleave(self.task_ids, repeats, dim=0) 84 | 85 | task_embs = self.task_embs(self.task_ids) 86 | task_embs = self.task_proj(task_embs) 87 | 88 | hyper_weight_A = self.hyper_weight_A(task_embs).view(input.size()[0], self.in_features, self.r) 89 | hyper_weight_B = self.hyper_weight_B(task_embs).view(input.size()[0], self.r, self.out_features) 90 | output = torch.matmul(input, hyper_weight_A) # bsi,bir->bsr 91 | output = torch.matmul(output, hyper_weight_B) # bsr,bro->bso 92 | output = F.linear(input, self.weight, self.bias) + output * self.scaling 93 | 94 | return output 95 | 96 | 97 | class SkilledLoRALinear(SkilledModule): 98 | """ Applies a linear function parameterised by a base bias 99 | and a weighted average of base and skill weights 100 | """ 101 | __constants__ = ['in_features', 'out_features'] 102 | in_features: int 103 | out_features: int 104 | weight: Tensor 105 | 106 | def __init__(self, 107 | n_tasks: int, 108 | n_skills: int, 109 | skills: Optional[Tensor], 110 | weight: Tensor, 111 | bias: Optional[Tensor], 112 | r: int = 16, 113 | freeze: bool = True 114 | ) -> None: 115 | super().__init__() 116 | self.out_features, self.in_features = weight.shape 117 | self.r = r 118 | 119 | if skills is None: 120 | self.skill_logits = nn.Parameter(torch.empty((n_tasks, n_skills)).uniform_(-1e-3, 1e-3)) 121 | self.is_learned = True 122 | else: 123 | self.register_buffer("skill_logits", skills) 124 | self.is_learned = False 125 | 126 | self.weight = nn.Parameter(weight.data) 127 | self.weight.requires_grad = not freeze 128 | 129 | skills_weight_A = weight.new_empty((n_skills, r * self.in_features)) 130 | skills_weight_B = weight.new_empty((n_skills, self.out_features * r)) 131 | self.skills_weight_A = nn.Parameter(skills_weight_A) 132 | self.skills_weight_B = nn.Parameter(skills_weight_B) 133 | 
self.scaling = 1 / self.r 134 | 135 | if bias is not None: 136 | self.bias = nn.Parameter(bias.data) 137 | self.bias.requires_grad = not freeze 138 | else: 139 | self.register_parameter('bias', None) 140 | 141 | self.reset_parameters() 142 | 143 | def reset_parameters(self): 144 | gain = calculate_gain(nonlinearity="leaky_relu", param=math.sqrt(5)) 145 | std = gain / math.sqrt(self.in_features) 146 | with torch.no_grad(): 147 | self.skills_weight_A.uniform_(-std, std) 148 | torch.nn.init.zeros_(self.skills_weight_B) 149 | 150 | def forward(self, input: Tensor) -> Tensor: 151 | # Provisions for inputs repeated for generation 152 | assert input.size()[0] % self.task_ids.size(0) == 0 153 | repeats = input.size()[0] // self.task_ids.size(0) 154 | if repeats > 1: 155 | self.task_ids = torch.repeat_interleave(self.task_ids, repeats, dim=0) 156 | 157 | skill_logits = self.skill_logits[self.task_ids] 158 | if self.is_learned: 159 | if self.training: 160 | skill_logits = RelaxedBernoulli(temperature=1., logits=skill_logits).rsample() 161 | else: 162 | skill_logits = torch.sigmoid(skill_logits) 163 | skill_logits = skill_logits / (skill_logits.sum(dim=-1, keepdim=True) + EPS) 164 | 165 | skills_weight_A = torch.mm(skill_logits, self.skills_weight_A).view(input.size()[0], self.in_features, self.r) 166 | skills_weight_B = torch.mm(skill_logits, self.skills_weight_B).view(input.size()[0], self.r, self.out_features) 167 | output = torch.matmul(input, skills_weight_A) # bsi,bir->bsr 168 | output = torch.matmul(output, skills_weight_B) # bsr,bro->bso 169 | output = F.linear(input, self.weight, self.bias) + output * self.scaling 170 | 171 | return output 172 | 173 | 174 | class SkilledLTSFTLinear(SkilledModule): 175 | """ Applies a linear function parameterised by a base bias 176 | and a weighted average of base and skill weights 177 | """ 178 | __constants__ = ['in_features', 'out_features'] 179 | in_features: int 180 | out_features: int 181 | weight: Tensor 182 | 183 | def __init__(self, 184 | n_tasks: int, 185 | n_skills: int, 186 | skills: Optional[Tensor], 187 | weight: Tensor, 188 | bias: Optional[Tensor], 189 | density: float = 0.1, 190 | freeze: bool = True 191 | ) -> None: 192 | super().__init__() 193 | self.out_features, self.in_features = weight.shape 194 | 195 | if skills is None: 196 | self.skill_logits = nn.Parameter(torch.empty((n_tasks, n_skills)).uniform_(-1e-3, 1e-3)) 197 | self.is_learned = True 198 | else: 199 | self.register_buffer("skill_logits", skills) 200 | self.is_learned = False 201 | 202 | self.weight = nn.Parameter(weight.data) 203 | self.weight.requires_grad = not freeze 204 | 205 | indices = itertools.product(range(self.out_features * self.in_features), range(n_skills)) 206 | k = int(self.out_features * self.in_features * n_skills * density) 207 | indices = random.sample(list(indices), k=k) 208 | indices = torch.LongTensor(indices).T 209 | values = torch.zeros((k, )) 210 | skills_weight = torch.sparse_coo_tensor(indices, values, (self.out_features * self.in_features, n_skills)) 211 | self.skills_weight = nn.Parameter(skills_weight.coalesce()) 212 | 213 | if bias is not None: 214 | self.bias = nn.Parameter(bias.data) 215 | self.bias.requires_grad = not freeze 216 | else: 217 | self.register_parameter('bias', None) 218 | 219 | def forward(self, input: Tensor) -> Tensor: 220 | # Provisions for inputs repeated for generation 221 | assert input.size()[0] % self.task_ids.size(0) == 0 222 | repeats = input.size()[0] // self.task_ids.size(0) 223 | if repeats > 1: 224 | self.task_ids 
= torch.repeat_interleave(self.task_ids, repeats, dim=0) 225 | 226 | skill_logits = self.skill_logits[self.task_ids] 227 | if self.is_learned: 228 | if self.training: 229 | skill_logits = RelaxedBernoulli(temperature=1., logits=skill_logits).rsample() 230 | else: 231 | skill_logits = torch.sigmoid(skill_logits) 232 | skill_logits = skill_logits / (skill_logits.sum(dim=-1, keepdim=True) + EPS) 233 | 234 | skills_weight = torch.sparse.mm(self.skills_weight, skill_logits.T).T.view(input.size()[0], self.in_features, self.out_features) 235 | output = torch.matmul(input, skills_weight) # bsi,bio->bso 236 | output = F.linear(input, self.weight, self.bias) + output 237 | 238 | return output 239 | --------------------------------------------------------------------------------
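Beyond the README snippet, the following minimal end-to-end training sketch shows how `SkilledMixin` is typically driven: the per-example `task_ids` tensor is always the first argument to `forward`/`generate`, and `add_prior=True` adds the IBP regulariser on the learned skill-allocation logits to the loss. This is only an illustrative sketch under stated assumptions, not code from the repository: the toy batches, task set, skill count, and hyper-parameters below are placeholders.

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

from polytropon import SkilledMixin

tokenizer = T5Tokenizer.from_pretrained("t5-small")
backbone = T5ForConditionalGeneration.from_pretrained("t5-small")

# Wrap the backbone: the pretrained weights are frozen (freeze=True by default),
# so only the LoRA skill inventory and the task-skill allocation logits train.
model = SkilledMixin(backbone, n_tasks=2, n_skills=4, skilled_variant="learned")
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

# Toy multitask batches: (source texts, target texts, task ids) — placeholders only.
batches = [
    (["translate English to German: The house is small."], ["Das Haus ist klein."], [0]),
    (["summarize: The Odyssey recounts the long journey home of Odysseus."], ["Odysseus sails home."], [1]),
]

model.train()
for sources, targets, ids in batches:
    enc = tokenizer(sources, return_tensors="pt", padding=True)
    labels = tokenizer(targets, return_tensors="pt", padding=True).input_ids
    task_ids = torch.LongTensor(ids)

    # task_ids comes first; add_prior=True adds the IBP prior on skill_logits.
    outputs = model.forward(task_ids, labels=labels, add_prior=True, **enc)
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Generation takes the per-example task ids as well.
model.eval()
enc = tokenizer(["translate English to German: The house is small."], return_tensors="pt")
generated = model.generate(torch.LongTensor([0]), **enc)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```

Because the base weights are frozen, the optimizer above only updates the small set of skill parameters and allocation logits, which is what makes the skill inventory cheap to train and to recombine across tasks.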