├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md └── mtl_lib.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 
67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to AdaTT 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 2 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * ... 36 | 37 | ## License 38 | By contributing to AdaTT, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaTT 2 | 3 | Welcome to the AdaTT repository! This repository provides a PyTorch library for multitask learning, specifically focused on the models evaluated in the paper ["AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations" (KDD'23)"](https://doi.org/10.1145/3580305.3599769). 4 | 5 | [**[arXiv]**](https://arxiv.org/abs/2304.04959) [**[slides]**](https://drive.google.com/file/d/1I8XpxPxwhP9KXuztEguYkuMM10kiJDS7/view?usp=sharing) 6 | 7 | ## Models 8 | 9 | This repository implements the following models: 10 | 11 | + AdaTT [[Paper]](https://doi.org/10.1145/3580305.3599769) 12 | + MMoE [[Paper]](https://dl.acm.org/doi/10.1145/3219819.3220007) 13 | + Multi-level MMoE (an extension of MMoE) 14 | + PLE [[Paper]](https://doi.org/10.1145/3383313.3412236) 15 | + Cross-stitch [[Paper]](https://openaccess.thecvf.com/content_cvpr_2016/papers/Misra_Cross-Stitch_Networks_for_CVPR_2016_paper.pdf) 16 | + Shared-bottom [[Paper]](https://link.springer.com/article/10.1023/a:1007379606734) 17 | 18 | To facilitate the integration and selection of these models, we have implemented a class called `CentralTaskArch`. 19 | 20 | ## License 21 | 22 | AdaTT is MIT-licensed. 23 | 24 | ## Citation 25 | 26 | If you find AdaTT's paper or code helpful, please consider citing: 27 | ``` 28 | @article{li2023adatt, 29 | title={AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations}, 30 | author={Li, Danwei and Zhang, Zhengyu and Yuan, Siyang and Gao, Mingze and Zhang, Weilin and Yang, Chaofei and Liu, Xi and Yang, Jiyan}, 31 | journal={arXiv preprint arXiv:2304.04959}, 32 | year={2023} 33 | } 34 | ``` 35 | -------------------------------------------------------------------------------- /mtl_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from dataclasses import dataclass, field 5 | from math import sqrt 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | logger: logging.Logger = logging.getLogger(__name__) 13 | 14 | 15 | @dataclass 16 | class MtlConfigs: 17 | mtl_model: str = "att_sp" # consider using enum 18 | num_task_experts: int = 1 19 | num_shared_experts: int = 1 20 | expert_out_dims: List[List[int]] = field(default_factory=list) 21 | # self_exp_res_connect is crucial for model performance and is designed for 22 | # ablation studies. Ensure that it is always set to True for best results! 
23 | self_exp_res_connect: bool = True 24 | expert_archs: Optional[List[List[int]]] = None 25 | gate_archs: Optional[List[List[int]]] = None 26 | num_experts: Optional[int] = None 27 | 28 | 29 | @dataclass(frozen=True) 30 | class ArchInputs: 31 | num_task: int = 3 32 | 33 | task_mlp: List[int] = field(default_factory=list) 34 | 35 | mtl_configs: Optional[MtlConfigs] = field(default=None) 36 | 37 | # Parameters related to activation function 38 | activation_type: str = "RELU" 39 | 40 | 41 | class AdaTTSp(nn.Module): 42 | """ 43 | paper title: "AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations" 44 | paper link: https://doi.org/10.1145/3580305.3599769 45 | Call Args: 46 | inputs: inputs is a tensor of dimension 47 | [batch_size, self.num_tasks, self.input_dim]. 48 | Experts in the same module share the same input. 49 | outputs dimensions: [B, T, D_out] 50 | 51 | Example:: 52 | AdaTTSp( 53 | input_dim=256, 54 | expert_out_dims=[[128, 128]], 55 | num_tasks=8, 56 | num_task_experts=2, 57 | self_exp_res_connect=True, 58 | ) 59 | """ 60 | 61 | def __init__( 62 | self, 63 | input_dim: int, 64 | expert_out_dims: List[List[int]], 65 | num_tasks: int, 66 | num_task_experts: int, 67 | self_exp_res_connect: bool = True, 68 | activation: str = "RELU", 69 | ) -> None: 70 | super().__init__() 71 | if len(expert_out_dims) == 0: 72 | logger.warning( 73 | "AdaTTSp is noop! size of expert_out_dims which is the number of " 74 | "extraction layers should be at least 1." 75 | ) 76 | return 77 | self.num_extraction_layers: int = len(expert_out_dims) 78 | self.num_tasks = num_tasks 79 | self.num_task_experts = num_task_experts 80 | self.total_experts_per_layer: int = num_task_experts * num_tasks 81 | self.self_exp_res_connect = self_exp_res_connect 82 | self.experts = torch.nn.ModuleList() 83 | self.gate_weights = torch.nn.ModuleList() 84 | 85 | self_exp_weight_list = [] 86 | layer_input_dim = input_dim 87 | for expert_out_dim in expert_out_dims: 88 | self.experts.append( 89 | torch.nn.ModuleList( 90 | [ 91 | MLP(layer_input_dim, expert_out_dim, activation) 92 | for i in range(self.total_experts_per_layer) 93 | ] 94 | ) 95 | ) 96 | 97 | self.gate_weights.append( 98 | torch.nn.ModuleList( 99 | [ 100 | torch.nn.Sequential( 101 | torch.nn.Linear( 102 | layer_input_dim, self.total_experts_per_layer 103 | ), 104 | torch.nn.Softmax(dim=-1), 105 | ) 106 | for _ in range(num_tasks) 107 | ] 108 | ) 109 | ) # self.gate_weights is of shape L X T, after we loop over all layers. 110 | 111 | if self_exp_res_connect and num_task_experts > 1: 112 | params = torch.empty(num_tasks, num_task_experts) 113 | scale = sqrt(1.0 / num_task_experts) 114 | torch.nn.init.uniform_(params, a=-scale, b=scale) 115 | self_exp_weight_list.append(torch.nn.Parameter(params)) 116 | 117 | layer_input_dim = expert_out_dim[-1] 118 | 119 | self.self_exp_weights = nn.ParameterList(self_exp_weight_list) 120 | 121 | def forward( 122 | self, 123 | inputs: torch.Tensor, 124 | ) -> torch.Tensor: 125 | for layer_i in range(self.num_extraction_layers): 126 | # all task expert outputs. 
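# inputs has shape [B, T, D_in] at the first layer and [B, T, D_out] of the previous layer afterwards.
# Expert expert_i serves task expert_i // num_task_experts, so all experts of one task read the same task slice of inputs.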
127 | experts_out = torch.stack( 128 | [ 129 | expert(inputs[:, expert_i // self.num_task_experts, :]) 130 | for expert_i, expert in enumerate(self.experts[layer_i]) 131 | ], 132 | dim=1, 133 | ) # [B * E (total experts) * D_out] 134 | 135 | gates = torch.stack( 136 | [ 137 | gate_weight( 138 | inputs[:, task_i, :] 139 | ) # W ([B, D]) * S ([D, E]) -> G, dim is [B, E] 140 | for task_i, gate_weight in enumerate(self.gate_weights[layer_i]) 141 | ], 142 | dim=1, 143 | ) # [B, T, E] 144 | fused_experts_out = torch.bmm( 145 | gates, 146 | experts_out, 147 | ) # [B, T, E] X [B * E (total experts) * D_out] -> [B, T, D_out] 148 | 149 | if self.self_exp_res_connect: 150 | if self.num_task_experts > 1: 151 | # residual from the linear combination of tasks' own experts. 152 | self_exp_weighted = torch.einsum( 153 | "te,bted->btd", 154 | self.self_exp_weights[layer_i], 155 | experts_out.view( 156 | experts_out.size(0), 157 | self.num_tasks, 158 | self.num_task_experts, 159 | -1, 160 | ), # [B * E (total experts) * D_out] -> [B * T * E_task * D_out] 161 | ) # bmm: [T * E_task] X [B * T * E_task * D_out] -> [B, T, D_out] 162 | 163 | fused_experts_out = ( 164 | fused_experts_out + self_exp_weighted 165 | ) # [B, T, D_out] 166 | else: 167 | fused_experts_out = fused_experts_out + experts_out 168 | 169 | inputs = fused_experts_out 170 | 171 | return inputs 172 | 173 | 174 | class AdaTTWSharedExps(nn.Module): 175 | """ 176 | paper title: "AdaTT: Adaptive Task-to-Task Fusion Network for Multitask Learning in Recommendations" 177 | paper link: https://doi.org/10.1145/3580305.3599769 178 | Call Args: 179 | inputs: inputs is a tensor of dimension 180 | [batch_size, self.num_tasks, self.input_dim]. 181 | Experts in the same module share the same input. 182 | outputs dimensions: [B, T, D_out] 183 | 184 | Example:: 185 | AdaTTWSharedExps( 186 | input_dim=256, 187 | expert_out_dims=[[128, 128]], 188 | num_tasks=8, 189 | num_shared_experts=1, 190 | num_task_experts=2, 191 | self_exp_res_connect=True, 192 | ) 193 | """ 194 | 195 | def __init__( 196 | self, 197 | input_dim: int, 198 | expert_out_dims: List[List[int]], 199 | num_tasks: int, 200 | num_shared_experts: int, 201 | num_task_experts: Optional[int] = None, 202 | num_task_expert_list: Optional[List[int]] = None, 203 | # Set num_task_expert_list for experimenting with a flexible number of 204 | # experts for different task_specific units. 205 | self_exp_res_connect: bool = True, 206 | activation: str = "RELU", 207 | ) -> None: 208 | super().__init__() 209 | if len(expert_out_dims) == 0: 210 | logger.warning( 211 | "AdaTTWSharedExps is noop! size of expert_out_dims which is the number of " 212 | "extraction layers should be at least 1." 213 | ) 214 | return 215 | self.num_extraction_layers: int = len(expert_out_dims) 216 | self.num_tasks = num_tasks 217 | assert (num_task_experts is None) ^ (num_task_expert_list is None) 218 | if num_task_experts is not None: 219 | self.num_expert_list = [num_task_experts for _ in range(num_tasks)] 220 | else: 221 | # num_expert_list is guaranteed to be not None here. 
222 | # pyre-ignore 223 | self.num_expert_list: List[int] = num_task_expert_list 224 | self.num_expert_list.append(num_shared_experts) 225 | 226 | self.total_experts_per_layer: int = sum(self.num_expert_list) 227 | self.self_exp_res_connect = self_exp_res_connect 228 | self.experts = torch.nn.ModuleList() 229 | self.gate_weights = torch.nn.ModuleList() 230 | 231 | layer_input_dim = input_dim 232 | for layer_i, expert_out_dim in enumerate(expert_out_dims): 233 | self.experts.append( 234 | torch.nn.ModuleList( 235 | [ 236 | MLP(layer_input_dim, expert_out_dim, activation) 237 | for i in range(self.total_experts_per_layer) 238 | ] 239 | ) 240 | ) 241 | 242 | num_full_active_modules = ( 243 | num_tasks 244 | if layer_i == self.num_extraction_layers - 1 245 | else num_tasks + 1 246 | ) 247 | 248 | self.gate_weights.append( 249 | torch.nn.ModuleList( 250 | [ 251 | torch.nn.Sequential( 252 | torch.nn.Linear( 253 | layer_input_dim, self.total_experts_per_layer 254 | ), 255 | torch.nn.Softmax(dim=-1), 256 | ) 257 | for _ in range(num_full_active_modules) 258 | ] 259 | ) 260 | ) # self.gate_weights is a 2d module list of shape L X T (+ 1), after we loop over all layers. 261 | 262 | layer_input_dim = expert_out_dim[-1] 263 | 264 | self_exp_weight_list = [] 265 | if self_exp_res_connect: 266 | # If any tasks have number of experts not equal to 1, we learn linear combinations of native experts. 267 | if any(num_experts != 1 for num_experts in self.num_expert_list): 268 | for i in range(num_tasks + 1): 269 | num_full_active_layer = ( 270 | self.num_extraction_layers - 1 271 | if i == num_tasks 272 | else self.num_extraction_layers 273 | ) 274 | params = torch.empty( 275 | num_full_active_layer, 276 | self.num_expert_list[i], 277 | ) 278 | scale = sqrt(1.0 / self.num_expert_list[i]) 279 | torch.nn.init.uniform_(params, a=-scale, b=scale) 280 | self_exp_weight_list.append(torch.nn.Parameter(params)) 281 | 282 | self.self_exp_weights = nn.ParameterList(self_exp_weight_list) 283 | 284 | self.expert_input_idx: List[int] = [] 285 | for i in range(num_tasks + 1): 286 | self.expert_input_idx.extend([i for _ in range(self.num_expert_list[i])]) 287 | 288 | def forward( 289 | self, 290 | inputs: torch.Tensor, 291 | ) -> torch.Tensor: 292 | for layer_i in range(self.num_extraction_layers): 293 | num_full_active_modules = ( 294 | self.num_tasks 295 | if layer_i == self.num_extraction_layers - 1 296 | else self.num_tasks + 1 297 | ) 298 | # all task expert outputs. 299 | experts_out = torch.stack( 300 | [ 301 | expert(inputs[:, self.expert_input_idx[expert_i], :]) 302 | for expert_i, expert in enumerate(self.experts[layer_i]) 303 | ], 304 | dim=1, 305 | ) # [B * E (total experts) * D_out] 306 | 307 | # gate weights for fusing all experts. 308 | gates = torch.stack( 309 | [ 310 | gate_weight(inputs[:, i, :]) # [B, D] * [D, E] -> [B, E] 311 | for i, gate_weight in enumerate(self.gate_weights[layer_i]) 312 | ], 313 | dim=1, 314 | ) # [B, T (+ 1), E] 315 | 316 | # add all expert gate weights with native expert weights. 
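# When every module has exactly one expert, adding 1.0 to that expert's own gate weight below is equivalent to an identity residual connection.
# Otherwise, a learned per-layer weight vector (self_exp_weights) is added, so each module keeps a direct, trainable path to its own experts' outputs.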
317 | if self.self_exp_res_connect: 318 | prev_idx = 0 319 | use_unit_naive_weights = all( 320 | num_expert == 1 for num_expert in self.num_expert_list 321 | ) 322 | for module_i in range(num_full_active_modules): 323 | next_idx = self.num_expert_list[module_i] + prev_idx 324 | if use_unit_naive_weights: 325 | gates[:, module_i, prev_idx:next_idx] += torch.ones( 326 | 1, self.num_expert_list[module_i] 327 | ) 328 | else: 329 | gates[:, module_i, prev_idx:next_idx] += self.self_exp_weights[ 330 | module_i 331 | ][layer_i].unsqueeze(0) 332 | prev_idx = next_idx 333 | 334 | fused_experts_out = torch.bmm( 335 | gates, 336 | experts_out, 337 | ) # [B, T (+ 1), E (total)] X [B * E (total) * D_out] -> [B, T (+ 1), D_out] 338 | 339 | inputs = fused_experts_out 340 | 341 | return inputs 342 | 343 | 344 | class MLP(nn.Module): 345 | """ 346 | Args: 347 | input_dim (int): dimension of the input tensor. 348 | mlp_arch (List[int]): output dimension of each linear layer, in order. 349 | activation (str): activation applied after each linear layer; only "RELU" is currently supported. 350 | 351 | Call Args: 352 | input (torch.Tensor): tensor of shape (B, I) 353 | 354 | Returns: 355 | output (torch.Tensor): MLP result 356 | 357 | Example:: 358 | 359 | mlp = MLP(100, [100]) 360 | 361 | """ 362 | 363 | def __init__( 364 | self, 365 | input_dim: int, 366 | mlp_arch: List[int], 367 | activation: str = "RELU", 368 | bias: bool = True, 369 | ) -> None: 370 | super().__init__() 371 | 372 | mlp_net = [] 373 | for mlp_dim in mlp_arch: 374 | mlp_net.append( 375 | nn.Linear(in_features=input_dim, out_features=mlp_dim, bias=bias) 376 | ) 377 | if activation == "RELU": 378 | mlp_net.append(nn.ReLU()) 379 | else: 380 | raise ValueError("only RELU is included currently") 381 | input_dim = mlp_dim 382 | self.mlp_net = nn.Sequential(*mlp_net) 383 | 384 | def forward( 385 | self, 386 | input: torch.Tensor, 387 | ) -> torch.Tensor: 388 | return self.mlp_net(input) 389 | 390 | 391 | class SharedBottom(nn.Module): 392 | def __init__( 393 | self, input_dim: int, hidden_dims: List[int], num_tasks: int, activation: str 394 | ) -> None: 395 | super().__init__() 396 | self.bottom_projection = MLP(input_dim, hidden_dims, activation) 397 | self.num_tasks: int = num_tasks 398 | 399 | def forward( 400 | self, 401 | input: torch.Tensor, 402 | ) -> torch.Tensor: 403 | # input dim [B, D_in] 404 | # output dim [B, T, D_out] 405 | return self.bottom_projection(input).unsqueeze(1).expand(-1, self.num_tasks, -1) 406 | 407 | 408 | class CrossStitch(torch.nn.Module): 409 | """ 410 | cross-stitch 411 | paper title: "Cross-stitch Networks for Multi-task Learning".
412 | paper link: https://openaccess.thecvf.com/content_cvpr_2016/papers/Misra_Cross-Stitch_Networks_for_CVPR_2016_paper.pdf 413 | """ 414 | 415 | def __init__( 416 | self, 417 | input_dim: int, 418 | expert_archs: List[List[int]], 419 | num_tasks: int, 420 | activation: str = "RELU", 421 | ) -> None: 422 | super().__init__() 423 | self.num_layers: int = len(expert_archs) 424 | self.num_tasks = num_tasks 425 | self.experts = torch.nn.ModuleList() 426 | self.stitchs = torch.nn.ModuleList() 427 | 428 | expert_input_dim = input_dim 429 | for layer_ind in range(self.num_layers): 430 | self.experts.append( 431 | torch.nn.ModuleList( 432 | [ 433 | MLP( 434 | expert_input_dim, 435 | expert_archs[layer_ind], 436 | activation, 437 | ) 438 | for _ in range(self.num_tasks) 439 | ] 440 | ) 441 | ) 442 | 443 | self.stitchs.append( 444 | torch.nn.Linear( 445 | self.num_tasks, 446 | self.num_tasks, 447 | bias=False, 448 | ) 449 | ) 450 | 451 | expert_input_dim = expert_archs[layer_ind][-1] 452 | 453 | def forward(self, input: torch.Tensor) -> torch.Tensor: 454 | """ 455 | input dim [B, T, D_in] 456 | output dim [B, T, D_out] 457 | """ 458 | x = input 459 | 460 | for layer_ind in range(self.num_layers): 461 | expert_out = torch.stack( 462 | [ 463 | expert(x[:, expert_ind, :]) # [B, D_out] 464 | for expert_ind, expert in enumerate(self.experts[layer_ind]) 465 | ], 466 | dim=1, 467 | ) # [B, T, D_out] 468 | 469 | stitch_out = self.stitchs[layer_ind](expert_out.transpose(1, 2)).transpose( 470 | 1, 2 471 | ) # [B, T, D_out] 472 | 473 | x = stitch_out 474 | 475 | return x 476 | 477 | 478 | class MLMMoE(torch.nn.Module): 479 | """ 480 | Multi-level Multi-gate Mixture of Experts 481 | This code implements a multi-level extension of the MMoE model, as described in the 482 | paper titled "Modeling Task Relationships in Multi-task Learning with Multi-gate 483 | Mixture-of-Experts". 484 | Paper link: https://dl.acm.org/doi/10.1145/3219819.3220007 485 | To run the original MMoE, use only one fusion level. For example, set expert_archs as 486 | [[96, 48]]. 487 | To configure multiple fusion levels, set expert_archs as something like [[96], [48]]. 
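Example (illustrative settings)::
    MLMMoE(
        input_dim=256,
        expert_archs=[[128], [64]],
        gate_archs=[[], []],
        num_tasks=8,
        num_experts=4,
    )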
488 | """ 489 | 490 | def __init__( 491 | self, 492 | input_dim: int, 493 | expert_archs: List[List[int]], 494 | gate_archs: List[List[int]], 495 | num_tasks: int, 496 | num_experts: int, 497 | activation: str = "RELU", 498 | ) -> None: 499 | super().__init__() 500 | self.num_layers: int = len(expert_archs) 501 | self.num_tasks: int = num_tasks 502 | self.num_experts = num_experts 503 | self.experts = torch.nn.ModuleList() 504 | self.gates = torch.nn.ModuleList() 505 | 506 | expert_input_dim = input_dim 507 | for layer_ind in range(self.num_layers): 508 | self.experts.append( 509 | torch.nn.ModuleList( 510 | [ 511 | MLP( 512 | expert_input_dim, 513 | expert_archs[layer_ind], 514 | activation, 515 | ) 516 | for _ in range(self.num_experts) 517 | ] 518 | ) 519 | ) 520 | self.gates.append( 521 | torch.nn.ModuleList( 522 | [ 523 | torch.nn.Sequential( 524 | MLP( 525 | input_dim, 526 | gate_archs[layer_ind], 527 | activation, 528 | ), 529 | torch.nn.Linear( 530 | gate_archs[layer_ind][-1] 531 | if gate_archs[layer_ind] 532 | else input_dim, 533 | self.num_experts, 534 | ), 535 | torch.nn.Softmax(dim=-1), 536 | ) 537 | for _ in range( 538 | self.num_experts 539 | if layer_ind < self.num_layers - 1 540 | else self.num_tasks 541 | ) 542 | ] 543 | ) 544 | ) 545 | expert_input_dim = expert_archs[layer_ind][-1] 546 | 547 | def forward(self, input: torch.Tensor) -> torch.Tensor: 548 | """ 549 | input dim [B, D_in] 550 | output dim [B, T, D_out] 551 | """ 552 | x = input.unsqueeze(1).expand([-1, self.num_experts, -1]) # [B, E, D_in] 553 | 554 | for layer_ind in range(self.num_layers): 555 | expert_out = torch.stack( 556 | [ 557 | expert(x[:, expert_ind, :]) # [B, D_out] 558 | for expert_ind, expert in enumerate(self.experts[layer_ind]) 559 | ], 560 | dim=1, 561 | ) # [B, E, D_out] 562 | 563 | gate_out = torch.stack( 564 | [ 565 | gate(input) # [B, E] 566 | for gate_ind, gate in enumerate(self.gates[layer_ind]) 567 | ], 568 | dim=1, 569 | ) # [B, T, E] 570 | 571 | gated_out = torch.matmul(gate_out, expert_out) # [B, T, D_out] 572 | 573 | x = gated_out 574 | return x 575 | 576 | 577 | class PLE(nn.Module): 578 | """ 579 | PLE module is based on the paper "Progressive Layered Extraction (PLE): A 580 | Novel Multi-Task Learning (MTL) Model for Personalized Recommendations". 581 | Paper link: https://doi.org/10.1145/3383313.3412236 582 | PLE aims to address negative transfer and seesaw phenomenon in multi-task 583 | learning. PLE distinguishes shared and task-specic experts explicitly and 584 | adopts a progressive routing mechanism to extract and separate deeper 585 | semantic knowledge gradually. When there is only one extraction layer, PLE 586 | falls back to CGC. 587 | 588 | Args: 589 | input_dim: input embedding dimension 590 | expert_out_dims (List[List[int]]): dimension of an expert's output at 591 | each layer. This list's length equals the number of extraction 592 | layers 593 | num_tasks: number of tasks 594 | num_task_experts: number of experts for each task module at each layer. 595 | * If the number of experts is the same for all tasks, use an 596 | integer here. 597 | * If the number of experts is different for different tasks, use a 598 | list of integers here. 599 | num_shared_experts: number of experts for shared module at each layer 600 | 601 | Call Args: 602 | inputs: inputs is a tensor of dimension [batch_size, self.num_tasks + 1, 603 | self.input_dim]. Task specific module inputs are placed first, followed 604 | by shared module input. 
(Experts in the same module share the same input) 605 | 606 | Returns: 607 | output: output of the extraction layers, to be fed into task-specific tower 608 | networks. It is a tensor of dimension [B, T, D_out], with one slice per task. 609 | 610 | Example:: 611 | PLE( 612 | input_dim=256, 613 | expert_out_dims=[[128]], 614 | num_tasks=8, 615 | num_task_experts=2, 616 | num_shared_experts=2, 617 | ) 618 | 619 | """ 620 | 621 | def __init__( 622 | self, 623 | input_dim: int, 624 | expert_out_dims: List[List[int]], 625 | num_tasks: int, 626 | num_task_experts: Union[int, List[int]], 627 | num_shared_experts: int, 628 | activation: str = "RELU", 629 | ) -> None: 630 | super().__init__() 631 | if len(expert_out_dims) == 0: 632 | raise ValueError("Expert out dims cannot be empty list") 633 | self.num_extraction_layers: int = len(expert_out_dims) 634 | self.num_tasks = num_tasks 635 | self.num_task_experts = num_task_experts 636 | if type(num_task_experts) is int: 637 | self.total_experts_per_layer: int = ( 638 | num_task_experts * num_tasks + num_shared_experts 639 | ) 640 | else: 641 | self.total_experts_per_layer: int = ( 642 | sum(num_task_experts) + num_shared_experts 643 | ) 644 | assert len(num_task_experts) == num_tasks 645 | self.num_shared_experts = num_shared_experts 646 | self.experts = nn.ModuleList() 647 | expert_input_dim = input_dim 648 | for expert_out_dim in expert_out_dims: 649 | self.experts.append( 650 | nn.ModuleList( 651 | [ 652 | MLP(expert_input_dim, expert_out_dim, activation) 653 | for i in range(self.total_experts_per_layer) 654 | ] 655 | ) 656 | ) 657 | expert_input_dim = expert_out_dim[-1] 658 | 659 | self.gate_weights = nn.ModuleList() 660 | selector_dim = input_dim 661 | for i in range(self.num_extraction_layers): 662 | expert_out_dim = expert_out_dims[i] 663 | # task specific gates. 664 | if type(num_task_experts) is int: 665 | gate_weights_in_layer = nn.ModuleList( 666 | [ 667 | nn.Sequential( 668 | nn.Linear( 669 | selector_dim, num_task_experts + num_shared_experts 670 | ), 671 | nn.Softmax(dim=-1), 672 | ) 673 | for i in range(num_tasks) 674 | ] 675 | ) 676 | else: 677 | gate_weights_in_layer = nn.ModuleList( 678 | [ 679 | nn.Sequential( 680 | nn.Linear( 681 | selector_dim, num_task_experts[i] + num_shared_experts 682 | ), 683 | nn.Softmax(dim=-1), 684 | ) 685 | for i in range(num_tasks) 686 | ] 687 | ) 688 | # Shared module gates. Note last layer has only task specific module gates for task towers later. 689 | if i != self.num_extraction_layers - 1: 690 | gate_weights_in_layer.append( 691 | nn.Sequential( 692 | nn.Linear(selector_dim, self.total_experts_per_layer), 693 | nn.Softmax(dim=-1), 694 | ) 695 | ) 696 | self.gate_weights.append(gate_weights_in_layer) 697 | 698 | selector_dim = expert_out_dim[-1] 699 | 700 | if type(self.num_task_experts) is list: 701 | experts_idx_2_task_idx = [] 702 | for i in range(num_tasks): 703 | # pyre-ignore 704 | experts_idx_2_task_idx += [i] * self.num_task_experts[i] 705 | experts_idx_2_task_idx += [num_tasks] * num_shared_experts 706 | self.experts_idx_2_task_idx: List[int] = experts_idx_2_task_idx 707 | 708 | def forward( 709 | self, 710 | inputs: torch.Tensor, 711 | ) -> torch.Tensor: 712 | for layer_i in range(self.num_extraction_layers): 713 | # all task specific and shared experts' outputs. 714 | # Note first num_task_experts * num_tasks experts are task specific, 715 | # last num_shared_experts experts are shared.
716 | if type(self.num_task_experts) is int: 717 | experts_out = torch.stack( 718 | [ 719 | self.experts[layer_i][expert_i]( 720 | inputs[ 721 | :, 722 | # pyre-ignore 723 | min(expert_i // self.num_task_experts, self.num_tasks), 724 | :, 725 | ] 726 | ) 727 | for expert_i in range(self.total_experts_per_layer) 728 | ], 729 | dim=1, 730 | ) # [B * E (num experts) * D_out] 731 | else: 732 | experts_out = torch.stack( 733 | [ 734 | self.experts[layer_i][expert_i]( 735 | inputs[ 736 | :, 737 | self.experts_idx_2_task_idx[expert_i], 738 | :, 739 | ] 740 | ) 741 | for expert_i in range(self.total_experts_per_layer) 742 | ], 743 | dim=1, 744 | ) # [B * E (num experts) * D_out] 745 | 746 | gates_out = [] 747 | # Loop over all the gates in the layer. Note for the last layer, 748 | # there is no shared gating network. 749 | prev_idx = 0 750 | for gate_i in range(len(self.gate_weights[layer_i])): 751 | # This is for shared gating network, which uses all the experts. 752 | if gate_i == self.num_tasks: 753 | selected_matrix = experts_out # S_share 754 | # This is for task gating network, which only uses shared and its own experts. 755 | else: 756 | if type(self.num_task_experts) is int: 757 | task_experts_out = experts_out[ 758 | :, 759 | # pyre-ignore 760 | (gate_i * self.num_task_experts) : (gate_i + 1) 761 | # pyre-ignore 762 | * self.num_task_experts, 763 | :, 764 | ] # task specific experts 765 | else: 766 | # pyre-ignore 767 | next_idx = prev_idx + self.num_task_experts[gate_i] 768 | task_experts_out = experts_out[ 769 | :, 770 | prev_idx:next_idx, 771 | :, 772 | ] # task specific experts 773 | prev_idx = next_idx 774 | shared_experts_out = experts_out[ 775 | :, 776 | -self.num_shared_experts :, 777 | :, 778 | ] # shared experts 779 | selected_matrix = torch.concat( 780 | [task_experts_out, shared_experts_out], dim=1 781 | ) # S_k with dimension of [B * E_selected * D_out] 782 | 783 | gates_out.append( 784 | torch.bmm( 785 | self.gate_weights[layer_i][gate_i]( 786 | inputs[:, gate_i, :] 787 | ).unsqueeze(dim=1), 788 | selected_matrix, 789 | ) 790 | # W * S -> G 791 | # [B, 1, E_selected] X [B * E_selected * D_out] -> [B, 1, D_out] 792 | ) 793 | inputs = torch.cat(gates_out, dim=1) # [B, T, D_out] 794 | 795 | return inputs 796 | 797 | 798 | class CentralTaskArch(nn.Module): 799 | def __init__( 800 | self, 801 | mtl_configs: MtlConfigs, 802 | opts: ArchInputs, 803 | input_dim: int, 804 | ) -> None: 805 | super().__init__() 806 | self.opts = opts 807 | 808 | assert len(mtl_configs.expert_out_dims) > 0, "expert_out_dims is empty." 809 | self.num_tasks: int = opts.num_task 810 | 811 | self.mtl_model: str = mtl_configs.mtl_model 812 | logger.info(f"mtl_model is {mtl_configs.mtl_model}") 813 | expert_out_dims: List[List[int]] = mtl_configs.expert_out_dims 814 | # AdaTT-sp 815 | # consider consolidating the implementation of att_sp and att_g. 
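# Supported mtl_configs.mtl_model values: "att_sp" (AdaTT-sp), "att_g" (AdaTT-general with shared experts),
# "ple", "cross_st", "mmoe" (single- or multi-level), and "share_bottom"; any other value raises ValueError.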
816 | if mtl_configs.mtl_model == "att_sp": 817 | self.mtl_arch: nn.Module = AdaTTSp( 818 | input_dim=input_dim, 819 | expert_out_dims=expert_out_dims, 820 | num_tasks=self.num_tasks, 821 | num_task_experts=mtl_configs.num_task_experts, 822 | self_exp_res_connect=mtl_configs.self_exp_res_connect, 823 | activation=opts.activation_type, 824 | ) 825 | # AdaTT-general 826 | elif mtl_configs.mtl_model == "att_g": 827 | self.mtl_arch: nn.Module = AdaTTWSharedExps( 828 | input_dim=input_dim, 829 | expert_out_dims=expert_out_dims, 830 | num_tasks=self.num_tasks, 831 | num_task_experts=mtl_configs.num_task_experts, 832 | num_shared_experts=mtl_configs.num_shared_experts, 833 | self_exp_res_connect=mtl_configs.self_exp_res_connect, 834 | activation=opts.activation_type, 835 | ) 836 | # PLE 837 | elif mtl_configs.mtl_model == "ple": 838 | self.mtl_arch: nn.Module = PLE( 839 | input_dim=input_dim, 840 | expert_out_dims=expert_out_dims, 841 | num_tasks=self.num_tasks, 842 | num_task_experts=mtl_configs.num_task_experts, 843 | num_shared_experts=mtl_configs.num_shared_experts, 844 | activation=opts.activation_type, 845 | ) 846 | # cross-stitch 847 | elif mtl_configs.mtl_model == "cross_st": 848 | self.mtl_arch: nn.Module = CrossStitch( 849 | input_dim=input_dim, 850 | expert_archs=mtl_configs.expert_out_dims, 851 | num_tasks=self.num_tasks, 852 | activation=opts.activation_type, 853 | ) 854 | # multi-layer MMoE or MMoE 855 | elif mtl_configs.mtl_model == "mmoe": 856 | self.mtl_arch: nn.Module = MLMMoE( 857 | input_dim=input_dim, 858 | expert_archs=mtl_configs.expert_out_dims, 859 | gate_archs=[[] for i in range(len(mtl_configs.expert_out_dims))], 860 | num_tasks=self.num_tasks, 861 | num_experts=mtl_configs.num_shared_experts, 862 | activation=opts.activation_type, 863 | ) 864 | # shared bottom 865 | elif mtl_configs.mtl_model == "share_bottom": 866 | self.mtl_arch: nn.Module = SharedBottom( 867 | input_dim, 868 | [dim for dims in expert_out_dims for dim in dims], 869 | self.num_tasks, 870 | opts.activation_type, 871 | ) 872 | else: 873 | raise ValueError("invalid model type") 874 | 875 | task_modules_input_dim = expert_out_dims[-1][-1] 876 | self.task_modules: nn.ModuleList = nn.ModuleList( 877 | [ 878 | nn.Sequential( 879 | MLP( 880 | task_modules_input_dim, self.opts.task_mlp, opts.activation_type 881 | ), 882 | torch.nn.Linear(self.opts.task_mlp[-1], 1), 883 | ) 884 | for i in range(self.num_tasks) 885 | ] 886 | ) 887 | 888 | def forward( 889 | self, 890 | task_arch_input: torch.Tensor, 891 | ) -> List[torch.Tensor]: 892 | if self.mtl_model in ["att_sp", "cross_st"]: 893 | task_arch_input = task_arch_input.unsqueeze(1).expand( 894 | -1, self.num_tasks, -1 895 | ) 896 | elif self.mtl_model in ["att_g", "ple"]: 897 | task_arch_input = task_arch_input.unsqueeze(1).expand( 898 | -1, self.num_tasks + 1, -1 899 | ) 900 | 901 | task_specific_outputs = self.mtl_arch(task_arch_input) 902 | 903 | task_arch_output = [ 904 | task_module(task_specific_outputs[:, i, :]) 905 | for i, task_module in enumerate(self.task_modules) 906 | ] 907 | 908 | return task_arch_output 909 | --------------------------------------------------------------------------------
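Below is a minimal end-to-end usage sketch. It is not part of the library itself: it assumes `mtl_lib.py` is importable from the working directory and uses arbitrary batch and dimension values, purely to show how `MtlConfigs`, `ArchInputs`, and `CentralTaskArch` fit together.

```
# Minimal usage sketch (illustrative values; assumes mtl_lib.py is on the import path).
import torch

from mtl_lib import ArchInputs, CentralTaskArch, MtlConfigs

batch_size, input_dim, num_tasks = 4, 256, 3

configs = MtlConfigs(
    mtl_model="att_sp",  # or "att_g", "ple", "cross_st", "mmoe", "share_bottom"
    num_task_experts=2,
    num_shared_experts=1,
    expert_out_dims=[[128, 128], [64, 64]],  # two extraction/fusion layers
    self_exp_res_connect=True,
)
opts = ArchInputs(
    num_task=num_tasks,
    task_mlp=[32, 16],  # per-task tower MLP applied on top of the fusion output
    mtl_configs=configs,
    activation_type="RELU",
)
model = CentralTaskArch(configs, opts, input_dim)

dense_input = torch.randn(batch_size, input_dim)
task_outputs = model(dense_input)  # list of num_tasks tensors, each of shape [batch_size, 1]
print([t.shape for t in task_outputs])
```

The same wiring works for the other `mtl_model` strings; only the config fields each architecture reads (for example `num_shared_experts` for "att_g" and "ple") change their effect.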