class MMoEModule(pl.LightningModule):  # if you are not using pytorch lightning, subclass torch.nn.Module instead
    """Multi-gate Mixture-of-Experts (MMoE) layer.

    Re-implementation of Ma et al., "Modeling Task Relationships in Multi-task
    Learning with Multi-gate Mixture-of-Experts" (KDD 2018). A shared pool of
    experts is combined per task through task-specific softmax gates.

    Args:
        input_size: dimensionality of the input features.
        units: output dimensionality of each expert.
        num_experts: number of shared experts.
        num_tasks: number of tasks (one gate per task).
        use_cuda: if True, parameters are created on CUDA and inputs are moved
            there in ``forward``.
        use_expert_bias: if True, add a learnable bias to the expert outputs.
        use_gate_bias: if True, add a learnable bias to each gate's logits.
        expert_activation: optional callable applied to the expert outputs.
    """

    def __init__(self, input_size, units, num_experts, num_tasks,
                 use_cuda=True, use_expert_bias=False, use_gate_bias=False,
                 expert_activation=None):
        super(MMoEModule, self).__init__()

        # Create parameters directly on the requested device; this collapses
        # the original duplicated use_cuda/cpu branches into one code path.
        device = 'cuda' if use_cuda else 'cpu'

        # expert_kernels: (input_size, units, num_experts)
        self.expert_kernels = torch.nn.Parameter(
            torch.randn(input_size, units, num_experts, device=device),
            requires_grad=True)
        # One gate per task, each of shape (input_size, num_experts).
        # BUG FIX: the original used bare `nn.Parameter` in the CUDA branch,
        # but only `torch` is imported -> NameError; fully qualify the name.
        self.gate_kernels = torch.nn.ParameterList([
            torch.nn.Parameter(
                torch.randn(input_size, num_experts, device=device),
                requires_grad=True)
            for _ in range(num_tasks)])

        # expert_kernels_bias: (units, num_experts)
        self.expert_kernels_bias = torch.nn.Parameter(
            torch.randn(units, num_experts, device=device), requires_grad=True)
        # gate_kernels_bias: one (num_experts,) vector per task.
        self.gate_kernels_bias = torch.nn.ParameterList([
            torch.nn.Parameter(torch.randn(num_experts, device=device),
                               requires_grad=True)
            for _ in range(num_tasks)])

        self.use_cuda = use_cuda
        self.use_expert_bias = use_expert_bias
        self.use_gate_bias = use_gate_bias
        self.expert_activation = expert_activation

    def forward(self, x):
        """Apply the MMoE layer.

        Args:
            x: input tensor of shape (batch_size, input_size).

        Returns:
            A list of length ``num_tasks``; element i is the task-i mixture of
            expert outputs, shape (batch_size, units).
        """
        if self.use_cuda:
            x = x.cuda()

        # expert_outputs: (batch_size, units, num_experts) — every expert
        # applied to every sample in one einsum.
        expert_outputs = torch.einsum("ab,bcd->acd", (x, self.expert_kernels))
        if self.use_expert_bias:
            # Bias (units, num_experts) broadcasts over the batch dimension.
            expert_outputs = expert_outputs + self.expert_kernels_bias
        if self.expert_activation is not None:
            expert_outputs = self.expert_activation(expert_outputs)

        final_outputs = []
        for index, gate_kernel in enumerate(self.gate_kernels):
            # gate_output: (batch_size, num_experts) mixing logits for task i.
            gate_output = torch.einsum("ab,bc->ac", (x, gate_kernel))
            if self.use_gate_bias:
                # BUG FIX: the original referenced self.gate_kernel_bias
                # (missing 's') -> AttributeError when use_gate_bias=True.
                gate_output = gate_output + self.gate_kernels_bias[index]
            # Normalize the mixing weights over the experts.
            gate_output = torch.softmax(gate_output, dim=-1)

            # (batch, 1, num_experts) broadcasts against expert_outputs;
            # summing over the expert axis yields (batch, units).
            weighted_expert_output = expert_outputs * gate_output.unsqueeze(1)
            final_outputs.append(weighted_expert_output.sum(dim=2))

        return final_outputs