├── .gitignore
├── README.md
├── data
    ├── adult.data
    ├── adult.names
    └── adult.test
├── esmm.py
├── main.py
├── mmoe.py
├── model
    ├── model_esmm_1
    └── model_mmoe_1
├── model_train.py
├── pic
    ├── ESMM.png
    └── MMOE.png
└── utils.py


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
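# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multi-task models
## Advantages of multi-task learning
* Compared with single-task training, multi-task training is less prone to overfitting: the joint loss is constrained by the losses of several tasks at once, which restrains each individual task.
* Multi-task training is more elegant and consumes fewer resources. With three tasks, training them as three independent single-task models means training three models on the dataset, which needs a lot of resources and uses them inefficiently.
* It alleviates sparsity. The data for a single task is usually severely sparse, i.e. the ratio of positive to negative samples is badly imbalanced, so the task trains poorly on its own. With joint training and a share-bottom structure, the training data of the other tasks helps correct the shared bottom-layer embeddings, which alleviates sparsity to some extent.
* It corrects sample selection bias. Different tasks have different goals, so they select different samples. For example, the CTR task uses impression-to-click samples while the CVR task uses click-to-purchase samples, so the CVR samples are only a subset of the CTR samples. Training the tasks independently gives biased models, whereas multi-task training can work in a single sample space, which makes the sample-space problem easier to correct.

## Some problems with multi-task learning
* The performance of a single task is easily limited: the training loss is a weighted sum of the task losses, so the loss of an individual task may be under-trained.
* Tasks can show a seesaw effect, i.e. one task trains very well while another trains badly or shows no real improvement.
* With a joint loss there is no way to tell whether the current loss is the weighted sum of the best loss each task could reach on its own.

## The ESMM model
![avatar](./pic/ESMM.png)

### Reading the model
The model uses a two-tower-like layout that builds the user and item features separately, embeds and concatenates them, and feeds the result into the NN of each task. The overall structure is simple and clear, so not every detail is covered here. The points that need attention or can be optimized are:
* How the user-field and item-field features are fused. The figure is vague on this; my understanding is that each user feature is vectorized separately: a categorical feature is embedded directly as one vector, and a sequence feature is pooled into one embedding (the pooling is open to experiment: maxpool, avgpool, and sum are all worth trying). The feature embeddings are then summed element-wise to give one overall user embedding. Item features are handled the same way. Note that numeric features are not considered in the figure; if there are numeric features, they have to be concatenated on at the concatenate layer.
* The loss actually used for training is loss = loss_{ctr} + loss_{ctcvr}. The paper writes the term to the right of the plus sign using the product pCTR × pCVR, which is easy to misread as requiring a separate CVR loss. Over the entire sample space CVR cannot be trained directly, so a CVR loss cannot be computed at all: the input is the entire sample space, i.e. impressions, so impression-to-click is CTR and impression-to-purchase is CTCVR. This point needs to be understood clearly.
* Things worth trying: 1. better handling of sequence features, e.g. adding attention; 2. tuning the loss weights based on actual training results. The tasks have different positive/negative ratios, so their losses differ; if the two losses differ by orders of magnitude, the larger one dominates the optimization and the smaller one is under-trained.

## The MMOE model
![avatar](./pic/MMOE.png)

### Reading the model
The overall framework is basically the same as ESMM, so only the main difference is described here, and the figure shows it fairly directly. A traditional ESMM shares one share-bottom structure at the bottom, usually the embedding layer. This sharing binds the tasks tightly to each other, so no single task can reach its full potential. The improvement here is an expert mechanism: the outputs of several experts are fused with task-specific weights to represent each task, which to some extent relieves the constraint of ESMM's direct share bottom.
Going one step deeper: once the fused input embedding is available (think of it as the concatenation of the user embedding and the item embedding), several DNNs with identical structure are added to learn from it. Although the structure is identical, the parameters differ through initialization and through backpropagation, so this is equivalent to several same-structure DNNs trained on the same batch, each producing its own output. A gate mechanism then weights the DNN outputs to produce the final representation, which is fed into each task tower to get the final result.
My personal reading is that MMOE is essentially an ensemble version of ESMM. It is, however, an ensemble of internal modules rather than of different models, so its benefit remains to be verified. Anyone familiar with ensembling knows that combining diverse models gives a noticeable gain, while combining models with the same structure and different parameters usually gives very little.


## Experiments
### Data
The UCI dataset at https://archive.ics.uci.edu/ml/datasets/census+income is used.

### Feature processing
Numeric features are normalized and categorical features are embedded. Because the data does not distinguish user features from item features, the first 7 features are simply treated as user features and the remaining ones as item features, and the vectors are then concatenated.
The prediction targets are **income_50k** and **marital_status**.

### Results
| Model | auc_income | auc_marital |
| ---- | ---- | ---- |
| ESMM | 90.2% | 96.2% |
| MMOE | 89.6% | 96.1% |

On this small dataset the difference between ESMM and MMOE is very small; larger datasets have not been tried yet.

--------------------------------------------------------------------------------
/data/adult.names:
--------------------------------------------------------------------------------
| This data was extracted from the census bureau database found at
| http://www.census.gov/ftp/pub/DES/www/welcome.html
| Donor: Ronny Kohavi and Barry Becker,
|        Data Mining and Visualization
|        Silicon Graphics.
|        e-mail: ronnyk@sgi.com for questions.
| Split into train-test using MLC++ GenCVFiles (2/3, 1/3 random).
| 48842 instances, mix of continuous and discrete    (train=32561, test=16281)
| 45222 if instances with unknown values are removed (train=30162, test=15060)
| Duplicate or conflicting instances : 6
| Class probabilities for adult.all file
| Probability for the label '>50K'  : 23.93% / 24.78% (without unknowns)
| Probability for the label '<=50K' : 76.07% / 75.22% (without unknowns)
|
| Extraction was done by Barry Becker from the 1994 Census database. A set of
|   reasonably clean records was extracted using the following conditions:
|   ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0))
|
| Prediction task is to determine whether a person makes over 50K
| a year.
|
| First cited in:
| @inproceedings{kohavi-nbtree,
|    author={Ron Kohavi},
|    title={Scaling Up the Accuracy of Naive-Bayes Classifiers: a
|           Decision-Tree Hybrid},
|    booktitle={Proceedings of the Second International Conference on
|               Knowledge Discovery and Data Mining},
|    year = 1996,
|    pages={to appear}}
|
| Error Accuracy reported as follows, after removal of unknowns from
|    train/test sets):
|    C4.5       : 84.46+-0.30
|    Naive-Bayes: 83.88+-0.30
|    NBTree     : 85.90+-0.28
|
|
| Following algorithms were later run with the following error rates,
|    all after removal of unknowns and using the original train/test split.
|    All these numbers are straight runs using MLC++ with default values.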
|
|    Algorithm               Error
| -- ----------------        -----
| 1  C4.5                    15.54
| 2  C4.5-auto               14.46
| 3  C4.5 rules              14.94
| 4  Voted ID3 (0.6)         15.64
| 5  Voted ID3 (0.8)         16.47
| 6  T2                      16.84
| 7  1R                      19.54
| 8  NBTree                  14.10
| 9  CN2                     16.00
| 10 HOODG                   14.82
| 11 FSS Naive Bayes         14.05
| 12 IDTM (Decision table)   14.46
| 13 Naive-Bayes             16.12
| 14 Nearest-neighbor (1)    21.42
| 15 Nearest-neighbor (3)    20.35
| 16 OC1                     15.04
| 17 Pebls                   Crashed. Unknown why (bounds WERE increased)
|
| Conversion of original data as follows:
| 1. Discretized agrossincome into two ranges with threshold 50,000.
| 2. Convert U.S. to US to avoid periods.
| 3. Convert Unknown to "?"
| 4. Run MLC++ GenCVFiles to generate data,test.
|
| Description of fnlwgt (final weight)
|
| The weights on the CPS files are controlled to independent estimates of the
| civilian noninstitutional population of the US. These are prepared monthly
| for us by Population Division here at the Census Bureau. We use 3 sets of
| controls.
|  These are:
|          1. A single cell estimate of the population 16+ for each state.
|          2. Controls for Hispanic Origin by age and sex.
|          3. Controls by Race, age and sex.
|
| We use all three sets of controls in our weighting program and "rake" through
| them 6 times so that by the end we come back to all the controls we used.
|
| The term estimate refers to population totals derived from CPS by creating
| "weighted tallies" of any specified socio-economic characteristics of the
| population.
|
| People with similar demographic characteristics should have
| similar weights. There is one important caveat to remember
| about this statement. That is that since the CPS sample is
| actually a collection of 51 state samples, each with its own
| probability of selection, the statement only applies within
| state.


>50K, <=50K.

age: continuous.
workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
fnlwgt: continuous.
education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
education-num: continuous.
marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
sex: Female, Male.
capital-gain: continuous.
capital-loss: continuous.
hours-per-week: continuous.
native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.

--------------------------------------------------------------------------------
/esmm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021-04-13 14:42
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : script description


import torch
import torch.nn as nn


class ESMM(nn.Module):
    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, hidden_dim=[128, 64], dropouts=[0.5, 0.5],
                 output_size=1, num_task=2):
        """
        esmm model input parameters
        :param user_feature_dict: user feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param item_feature_dict: item feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param emb_dim: int, embedding size
        :param hidden_dim: list of ctr and ctcvr dnn hidden sizes
        :param dropouts: list of ctr and ctcvr dnn drop out probability
        :param output_size: int, output size
        :param num_task: int default 2 multitask numbers
        """
        super(ESMM, self).__init__()

        # check input parameters
        if user_feature_dict is None or item_feature_dict is None:
            raise Exception("input parameter user_feature_dict and item_feature_dict must be not None")
        if isinstance(user_feature_dict, dict) is False or isinstance(item_feature_dict, dict) is False:
            raise Exception("input parameter user_feature_dict and item_feature_dict must be dict")

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.num_task = num_task

        # initialize embeddings: a feature with more than one unique value is categorical,
        # a feature declared with exactly one unique value is treated as numeric
        user_cate_feature_nums, item_cate_feature_nums = 0, 0
        for user_cate, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_cate_feature_nums += 1
                setattr(self, user_cate, nn.Embedding(num[0], emb_dim))
        for item_cate, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_cate_feature_nums += 1
                setattr(self, item_cate, nn.Embedding(num[0], emb_dim))

        # user embedding + item embedding (one emb_dim block per categorical feature, one column per numeric feature)
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) + \
                      (len(user_feature_dict) - user_cate_feature_nums) + (len(item_feature_dict) - item_cate_feature_nums)

        # independent DNN tower for each task
        for i in range(self.num_task):
            tower = nn.ModuleList()
            hid_dim = [hidden_size] + hidden_dim
            for j in range(len(hid_dim) - 1):
                tower.add_module('ctr_hidden_{}'.format(j), nn.Linear(hid_dim[j], hid_dim[j + 1]))
                tower.add_module('ctr_batchnorm_{}'.format(j), nn.BatchNorm1d(hid_dim[j + 1]))
                tower.add_module('ctr_dropout_{}'.format(j), nn.Dropout(dropouts[j]))
            tower.add_module('task_last_layer', nn.Linear(hid_dim[-1], output_size))
            setattr(self, 'task_{}_dnn'.format(i + 1), tower)

    def forward(self, x):
        assert x.size()[1] == len(self.item_feature_dict) + len(self.user_feature_dict)
        # embedding
        user_embed_list, item_embed_list = list(), list()
        for user_feature, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_embed_list.append(getattr(self, user_feature)(x[:, num[1]].long()))
            else:
                user_embed_list.append(x[:, num[1]].unsqueeze(1))
        for item_feature, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_embed_list.append(getattr(self, item_feature)(x[:, num[1]].long()))
            else:
                item_embed_list.append(x[:, num[1]].unsqueeze(1))

        # fuse the embeddings by concatenation
        user_embed = torch.cat(user_embed_list, dim=1)
        item_embed = torch.cat(item_embed_list, dim=1)

        # hidden layer
        hidden = torch.cat([user_embed, item_embed], dim=1).float()

        # task towers: each returns raw logits; no pCTR x pCVR product is applied here
        task_outputs = list()
        for i in range(self.num_task):
            x = hidden
            for mod in getattr(self, 'task_{}_dnn'.format(i + 1)):
                x = mod(x)
            task_outputs.append(x)

        return task_outputs


if __name__ == "__main__":
    import numpy as np
    a = torch.from_numpy(np.array([[1, 2, 4, 2, 0.5, 0.1],
                                   [4, 5, 3, 8, 0.6, 0.43],
                                   [6, 3, 2, 9, 0.12, 0.32],
                                   [9, 1, 1, 1, 0.12, 0.45],
                                   [8, 3, 1, 4, 0.21, 0.67]]))
    user_cate_dict = {'user_id': (11, 0), 'user_list': (12, 3), 'user_num': (1, 4)}
    item_cate_dict = {'item_id': (8, 1), 'item_cate': (6, 2), 'item_num': (1, 5)}
    esmm = ESMM(user_cate_dict, item_cate_dict)
    tasks = esmm(a)
    print(tasks)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021-04-19 17:25
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : script description


from utils import data_preparation, TrainDataSet
from torch.utils.data import DataLoader
from model_train import train_model
from esmm import ESMM
from mmoe import MMOE
import torch
import torch.nn as nn


def main():
    train_data, test_data, user_feature_dict, item_feature_dict = data_preparation()
    # the last two columns are the labels (income_50k, marital_status); everything before them is a feature
    train_dataset = (train_data.iloc[:, :-2].values, train_data.iloc[:, -2].values, train_data.iloc[:, -1].values)
    # val_dataset = (val_data.iloc[:, :-2].values, val_data.iloc[:, -2].values, val_data.iloc[:, -1].values)
    test_dataset = (test_data.iloc[:, :-2].values, test_data.iloc[:, -2].values, test_data.iloc[:, -1].values)
    train_dataset = TrainDataSet(train_dataset)
    # val_dataset = TrainDataSet(val_dataset)
    test_dataset = TrainDataSet(test_dataset)

    # dataloader
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    # val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # PyTorch optimization settings
    learn_rate = 0.01
    bce_loss = nn.BCEWithLogitsLoss()
    early_stop = 3

    # train model (the test split doubles as the validation set here)
    # esmm Epoch 17 val loss is 1.164, income auc is 0.875 and marry auc is 0.953
    esmm = ESMM(user_feature_dict, item_feature_dict, emb_dim=64)
    optimizer = torch.optim.Adam(esmm.parameters(), lr=learn_rate)
    train_model(esmm, train_dataloader, test_dataloader, 20, bce_loss, optimizer, 'model/model_esmm_{}', early_stop)

    # mmoe
    mmoe = MMOE(user_feature_dict, item_feature_dict, emb_dim=64)
    optimizer = torch.optim.Adam(mmoe.parameters(), lr=learn_rate)
    train_model(mmoe, train_dataloader, test_dataloader, 20, bce_loss, optimizer, 'model/model_mmoe_{}', early_stop)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/mmoe.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021-04-19 12:12
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : script description

import torch
import torch.nn as nn


class MMOE(nn.Module):
    """
    MMOE for CTCVR problem
    """
    def __init__(self, user_feature_dict, item_feature_dict, emb_dim=128, n_expert=3, mmoe_hidden_dim=128,
                 hidden_dim=[128, 64], dropouts=[0.5, 0.5], output_size=1, expert_activation=None, num_task=2):
        """
        MMOE model input parameters
        :param user_feature_dict: user feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param item_feature_dict: item feature dict include: {feature_name: (feature_unique_num, feature_index)}
        :param emb_dim: int embedding dimension
        :param n_expert: int number of experts in mmoe
        :param mmoe_hidden_dim: int, output dimension of each expert (input dimension of the task towers)
        :param hidden_dim: list task tower hidden dimension
        :param dropouts: list of task dnn drop out probability
        :param output_size: int task output size
        :param expert_activation: callable applied to the expert outputs, e.g. torch.relu or torch.sigmoid
        :param num_task: int default 2 multitask numbers
        """
        super(MMOE, self).__init__()
        # check input parameters
        if user_feature_dict is None or item_feature_dict is None:
            raise Exception("input parameter user_feature_dict and item_feature_dict must be not None")
        if isinstance(user_feature_dict, dict) is False or isinstance(item_feature_dict, dict) is False:
            raise Exception("input parameter user_feature_dict and item_feature_dict must be dict")

        self.user_feature_dict = user_feature_dict
        self.item_feature_dict = item_feature_dict
        self.expert_activation = expert_activation
        self.num_task = num_task

        # initialize embeddings
        user_cate_feature_nums, item_cate_feature_nums = 0, 0
        for user_cate, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_cate_feature_nums += 1
                setattr(self, user_cate, nn.Embedding(num[0], emb_dim))
        for item_cate, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_cate_feature_nums += 1
                setattr(self, item_cate, nn.Embedding(num[0], emb_dim))

        # user embedding + item embedding
        hidden_size = emb_dim * (user_cate_feature_nums + item_cate_feature_nums) + \
                      (len(self.user_feature_dict) - user_cate_feature_nums) + (
                              len(self.item_feature_dict) - item_cate_feature_nums)

        # experts
        self.experts = torch.nn.Parameter(torch.rand(hidden_size, mmoe_hidden_dim, n_expert), requires_grad=True)
        self.experts.data.normal_(0, 1)
        self.experts_bias = torch.nn.Parameter(torch.rand(mmoe_hidden_dim, n_expert), requires_grad=True)
        # gates: kept in an nn.ParameterList so they are registered with the module; a plain Python
        # list would hide them from model.parameters() (no optimizer updates) and from model.to(device)
        self.gates = nn.ParameterList(
            [torch.nn.Parameter(torch.rand(hidden_size, n_expert), requires_grad=True) for _ in range(num_task)])
        for gate in self.gates:
            gate.data.normal_(0, 1)
        # gate biases: nn.ParameterList registers them with the module, so the optimizer updates them
        # and model.to(device) moves them (a plain Python list of Parameters would do neither)
        self.gates_bias = nn.ParameterList(
            [torch.nn.Parameter(torch.rand(n_expert), requires_grad=True) for _ in range(num_task)])

        # independent DNN tower for each task
        for i in range(self.num_task):
            tower = nn.ModuleList()
            hid_dim = [mmoe_hidden_dim] + hidden_dim
            for j in range(len(hid_dim) - 1):
                tower.add_module('ctr_hidden_{}'.format(j), nn.Linear(hid_dim[j], hid_dim[j + 1]))
                tower.add_module('ctr_batchnorm_{}'.format(j), nn.BatchNorm1d(hid_dim[j + 1]))
                tower.add_module('ctr_dropout_{}'.format(j), nn.Dropout(dropouts[j]))
            tower.add_module('task_last_layer', nn.Linear(hid_dim[-1], output_size))
            setattr(self, 'task_{}_dnn'.format(i + 1), tower)

    def forward(self, x):
        assert x.size()[1] == len(self.item_feature_dict) + len(self.user_feature_dict)
        # embedding
        user_embed_list, item_embed_list = list(), list()
        for user_feature, num in self.user_feature_dict.items():
            if num[0] > 1:
                user_embed_list.append(getattr(self, user_feature)(x[:, num[1]].long()))
            else:
                user_embed_list.append(x[:, num[1]].unsqueeze(1))
        for item_feature, num in self.item_feature_dict.items():
            if num[0] > 1:
                item_embed_list.append(getattr(self, item_feature)(x[:, num[1]].long()))
            else:
                item_embed_list.append(x[:, num[1]].unsqueeze(1))

        # fuse the embeddings by concatenation
        user_embed = torch.cat(user_embed_list, dim=1)
        item_embed = torch.cat(item_embed_list, dim=1)

        # hidden layer
        hidden = torch.cat([user_embed, item_embed], dim=1).float()  # batch * hidden_size

        # mmoe: every expert is a single linear map applied to the shared input
        experts_out = torch.einsum('ij, jkl -> ikl', hidden, self.experts)  # batch * mmoe_hidden_size * num_experts
        experts_out += self.experts_bias
        if self.expert_activation is not None:
            experts_out = self.expert_activation(experts_out)

        # one softmax gate per task, giving a weight for each expert
        gates_out = list()
        for idx, gate in enumerate(self.gates):
            gate_out = torch.einsum('ab, bc -> ac', hidden, gate)  # batch * num_experts
            if self.gates_bias is not None:
                gate_out += self.gates_bias[idx]
            gate_out = nn.Softmax(dim=-1)(gate_out)
            gates_out.append(gate_out)

        # weighted sum of the expert outputs for each task
        outs = list()
        for gate_output in gates_out:
            expanded_gate_output = torch.unsqueeze(gate_output, 1)  # batch * 1 * num_experts
            weighted_expert_output = experts_out * expanded_gate_output.expand_as(experts_out)  # batch * mmoe_hidden_size * num_experts
            outs.append(torch.sum(weighted_expert_output, 2))  # batch * mmoe_hidden_size

        # task tower
        task_outputs = list()
        for i in range(self.num_task):
            x = outs[i]
            for mod in getattr(self, 'task_{}_dnn'.format(i + 1)):
                x = mod(x)
            task_outputs.append(x)

        return task_outputs


if __name__ == "__main__":
    import numpy as np

    a = torch.from_numpy(np.array([[1, 2, 4, 2, 0.5, 0.1],
                                   [4, 5, 3, 8, 0.6, 0.43],
                                   [6, 3, 2, 9, 0.12, 0.32],
                                   [9, 1, 1, 1, 0.12, 0.45],
                                   [8, 3, 1, 4, 0.21, 0.67]]))
    user_cate_dict = {'user_id': (11, 0), 'user_list': (12, 3), 'user_num': (1, 4)}
    item_cate_dict = {'item_id': (8, 1), 'item_cate': (6, 2), 'item_num': (1, 5)}
    mmoe = MMOE(user_cate_dict, item_cate_dict)
    outs = mmoe(a)
    print(outs)

--------------------------------------------------------------------------------
/model/model_esmm_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/busesese/MultiTaskModel/69e7b4468bcc654b8066bf51dc9c8e34913ebd05/model/model_esmm_1

--------------------------------------------------------------------------------
/model/model_mmoe_1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/busesese/MultiTaskModel/69e7b4468bcc654b8066bf51dc9c8e34913ebd05/model/model_mmoe_1

--------------------------------------------------------------------------------
/model_train.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021-04-19 17:10
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : model train function

import torch
from tqdm import tqdm
from sklearn.metrics import roc_auc_score


def train_model(model, train_loader, val_loader, epoch, loss_function, optimizer, path, early_stop):
    """
    pytorch model train function
    :param model: pytorch model
    :param train_loader: dataloader, train data loader
    :param val_loader: dataloader, val data loader
    :param epoch: int, number of epochs
    :param loss_function: loss function of train model
    :param optimizer: pytorch optimizer
    :param path: save path
    :param early_stop: int, early stop number
    :return: None
    """
    # use GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # stop early once the validation loss has failed to decrease too many times
    patience, eval_loss = 0, 0

    # train
    for i in range(epoch):
        # switch back to training mode; model.eval() below would otherwise stay in effect
        # for every epoch after the first and disable dropout and batch-norm updates
        model.train()
        y_train_income_true = []
        y_train_income_predict = []
        y_train_marry_true = []
        y_train_marry_predict = []
        total_loss, count = 0, 0
        for idx, (x, y1, y2) in tqdm(enumerate(train_loader), total=len(train_loader)):
            x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
            predict = model(x)
            y_train_income_true += list(y1.squeeze().cpu().numpy())
            y_train_marry_true += list(y2.squeeze().cpu().numpy())
            y_train_income_predict += list(predict[0].squeeze().cpu().detach().numpy())
            y_train_marry_predict += list(predict[1].squeeze().cpu().detach().numpy())
            loss_1 = loss_function(predict[0], y1.unsqueeze(1).float())
            loss_2 = loss_function(predict[1], y2.unsqueeze(1).float())
            # unweighted sum of the two task losses
            loss = loss_1 + loss_2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += float(loss)
            count += 1
        torch.save(model, path.format(i + 1))
        income_auc = roc_auc_score(y_train_income_true, y_train_income_predict)
        marry_auc = roc_auc_score(y_train_marry_true, y_train_marry_predict)
        print("Epoch %d train loss is %.3f, income auc is %.3f and marry auc is %.3f" % (i + 1, total_loss / count,
                                                                                         income_auc, marry_auc))

        # validation
        total_eval_loss = 0
        model.eval()
        count_eval = 0
        y_val_income_true = []
        y_val_marry_true = []
        y_val_income_predict = []
        y_val_marry_predict = []
        # no gradients are needed while evaluating
        with torch.no_grad():
            for idx, (x, y1, y2) in tqdm(enumerate(val_loader), total=len(val_loader)):
                x, y1, y2 = x.to(device), y1.to(device), y2.to(device)
                predict = model(x)
                y_val_income_true += list(y1.squeeze().cpu().numpy())
                y_val_marry_true += list(y2.squeeze().cpu().numpy())
                y_val_income_predict += list(predict[0].squeeze().cpu().detach().numpy())
                y_val_marry_predict += list(predict[1].squeeze().cpu().detach().numpy())
                loss_1 = loss_function(predict[0], y1.unsqueeze(1).float())
                loss_2 = loss_function(predict[1], y2.unsqueeze(1).float())
                loss = loss_1 + loss_2
                total_eval_loss += float(loss)
                count_eval += 1
        income_auc = roc_auc_score(y_val_income_true, y_val_income_predict)
        marry_auc = roc_auc_score(y_val_marry_true, y_val_marry_predict)
        print("Epoch %d val loss is %.3f, income auc is %.3f and marry auc is %.3f" % (i + 1,
                                                                                       total_eval_loss / count_eval,
                                                                                       income_auc, marry_auc))

        # early stopping
        if i == 0:
            eval_loss = total_eval_loss / count_eval
        else:
            if total_eval_loss / count_eval < eval_loss:
                eval_loss = total_eval_loss / count_eval
            else:
                if patience < early_stop:
                    patience += 1
                else:
                    print("val loss has not decreased for %d epochs, stopping training" % patience)
                    break

--------------------------------------------------------------------------------
/pic/ESMM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/busesese/MultiTaskModel/69e7b4468bcc654b8066bf51dc9c8e34913ebd05/pic/ESMM.png

--------------------------------------------------------------------------------
/pic/MMOE.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/busesese/MultiTaskModel/69e7b4468bcc654b8066bf51dc9c8e34913ebd05/pic/MMOE.png

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Time : 2021-04-20 10:36
# @Author : WenYi
# @Contact : 1244058349@qq.com
# @Description : script description


import pandas as pd
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader


# data process
def data_preparation():
    # The column names are from
    column_names = ['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'marital_status', 'occupation',
                    'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
                    'income_50k']

    # Load the dataset in Pandas
    train_df = pd.read_csv(
        'data/adult.data',
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names
    )
    other_df = pd.read_csv(
        'data/adult.test',
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names
    )

    # tag marks which file a row came from so the two can be split apart again after joint encoding
    train_df['tag'] = 1
    other_df['tag'] = 0
    # drop rows that failed to parse (e.g. the header line of adult.test)
    other_df.dropna(inplace=True)
    # labels in adult.test end with a trailing '.', strip it so they match adult.data
    other_df['income_50k'] = other_df['income_50k'].apply(lambda x: x[:-1])
    data = pd.concat([train_df, other_df])
    data.dropna(inplace=True)
    # First group of tasks according to the paper
    label_columns = ['income_50k', 'marital_status']

    # categorical columns
    categorical_columns = ['workclass', 'education', 'occupation', 'relationship', 'race', 'sex', 'native_country']
    # binarize both labels (raw values keep the leading space from the CSV)
    for col in label_columns:
        if col == 'income_50k':
            data[col] = data[col].apply(lambda x: 0 if x == ' <=50K' else 1)
        else:
            data[col] = data[col].apply(lambda x: 0 if x == ' Never-married' else 1)

    # feature engineering: label-encode categorical columns, min-max scale numeric ones
    for col in column_names:
        if col not in label_columns + ['tag']:
            if col in categorical_columns:
                le = LabelEncoder()
                data[col] = le.fit_transform(data[col])
            else:
                mm = MinMaxScaler()
                data[col] = mm.fit_transform(data[[col]]).reshape(-1)
    # reorder so the features come first and the two labels come last
    data = data[['age', 'workclass', 'fnlwgt', 'education', 'education_num', 'occupation',
                 'relationship', 'race', 'sex', 'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
                 'income_50k', 'marital_status', 'tag']]

    # user feature, item feature: the first 7 columns are treated as user features, the rest as item features
    user_feature_dict, item_feature_dict = dict(), dict()
    for idx, col in enumerate(data.columns):
        if col not in label_columns + ['tag']:
            if idx < 7:
                if col in categorical_columns:
                    user_feature_dict[col] = (len(data[col].unique())+1, idx)
                else:
                    user_feature_dict[col] = (1, idx)
            else:
                if col in categorical_columns:
                    item_feature_dict[col] = (len(data[col].unique())+1, idx)
                else:
                    item_feature_dict[col] = (1, idx)

    # Split back into the original train and test files using the tag column
    train_data = data[data['tag'] == 1].drop('tag', axis=1)
    test_data = data[data['tag'] == 0].drop('tag', axis=1)

    # val data
    # train_data, val_data = train_test_split(train_data, test_size=0.5, random_state=2021)
    return train_data, test_data, user_feature_dict, item_feature_dict


class TrainDataSet(Dataset):
    def __init__(self, data):
        self.feature = data[0]
        self.label1 = data[1]
        self.label2 = data[2]

    def __getitem__(self, index):
        feature = self.feature[index]
        label1 = self.label1[index]
        label2 = self.label2[index]
        return feature, label1, label2

    def __len__(self):
        return len(self.feature)

--------------------------------------------------------------------------------
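The README describes the ESMM training loss as loss_{ctr} + loss_{ctcvr}, computed over the impression space. The following is a minimal sketch of that formula in plain Python, assuming per-sample tower logits and 0/1 labels. The names `sigmoid`, `bce`, and `esmm_loss` are hypothetical and are not part of this repository, whose `model_train.py` simply sums two independent binary cross-entropy losses.

```python
import math


def sigmoid(z):
    """Logistic function mapping a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-z))


def bce(p, y, eps=1e-7):
    """Binary cross-entropy for one sample; p is clamped to avoid log(0)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def esmm_loss(ctr_logits, cvr_logits, clicks, conversions):
    """Mean of loss_ctr + loss_ctcvr over all impressions (illustrative only).

    pCTCVR is the product pCTR * pCVR, and its label is click AND conversion,
    so no separate CVR loss is ever computed.
    """
    total = 0.0
    for a, b, y, z in zip(ctr_logits, cvr_logits, clicks, conversions):
        p_ctr = sigmoid(a)
        p_ctcvr = p_ctr * sigmoid(b)
        total += bce(p_ctr, y) + bce(p_ctcvr, y * z)
    return total / len(ctr_logits)


if __name__ == "__main__":
    # two impressions: one clicked and converted, one not clicked
    print(esmm_loss([0.0, 2.0], [0.0, -1.0], [1, 0], [1, 0]))
```

Written this way, the CVR tower only receives gradient through the CTCVR term, which is the point the README makes about CVR not being trainable directly on impression data.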