├── .gitignore ├── FRCTR ├── basic │ ├── __init__.py │ ├── activation.py │ ├── features.py │ ├── initializers.py │ └── metrics.py ├── common │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── attention_layer.cpython-36.pyc │ │ ├── basic_layers.cpython-36.pyc │ │ ├── interaction_layer.cpython-36.pyc │ │ └── layers.cpython-36.pyc │ ├── attention_layer.py │ ├── basic_layers.py │ ├── interaction_layer.py │ ├── layers.py │ ├── prediction_layer.py │ └── refinement_layer.py ├── data │ ├── Criteo │ │ ├── CriteoDataLoader.py │ │ └── __pycache__ │ │ │ └── CriteoDataLoader.cpython-36.pyc │ ├── Frappe │ │ ├── FrappeDataLoader.py │ │ ├── __pycache__ │ │ │ └── FrappeDataLoader.cpython-36.pyc │ │ ├── frappe.test.libfm │ │ ├── frappe.train.libfm │ │ └── frappe.validation.libfm │ ├── __init__.py │ └── __pycache__ │ │ └── __init__.cpython-36.pyc ├── model_zoo │ ├── __init__.py │ ├── afm.py │ ├── afnp.py │ ├── autoint.py │ ├── dcap.py │ ├── dcn.py │ ├── dcnv2.py │ ├── deepfm.py │ ├── fed.py │ ├── fibinet.py │ ├── fint.py │ ├── fm.py │ ├── fmfm.py │ ├── fnn.py │ ├── fwfm.py │ ├── lr.py │ ├── nfm.py │ ├── pnn.py │ └── xdeepfm.py ├── module_zoo │ ├── __init__.py │ ├── contextnet.py │ ├── dfen.py │ ├── drm.py │ ├── fal.py │ ├── fen.py │ ├── frnet.py │ ├── fwn.py │ ├── gatenet.py │ ├── gfrl.py │ ├── selfatt.py │ ├── senet.py │ └── skip.py └── utils │ ├── __init__.py │ ├── auc.py │ ├── earlystoping.py │ └── util.py ├── README.md ├── RefineCTR.png ├── evaluation ├── figure │ ├── RefineCTR.png │ ├── deepfm_auc.jpg │ ├── deepfm_ll.jpg │ ├── refineCTR framework.png │ └── refinectr structure.png └── mains │ ├── main_criteo_base.py │ └── main_frappe_base.py └── refineCTR framework.png /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /FRCTR/basic/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | -------------------------------------------------------------------------------- /FRCTR/basic/activation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | class Dice(nn.Module): 10 | def __init__(self, epsilon=1e-3): 11 | super(Dice, self).__init__() 12 | self.epsilon = epsilon 13 | self.alpha = nn.Parameter(torch.randn(1)) 14 | 15 | def forward(self, x: torch.Tensor): 16 | avg = x.mean(dim=1) # N 17 | avg = avg.unsqueeze(dim=1) # N * 1 18 | var = torch.pow(x - avg, 2) + self.epsilon # N * num_neurons 19 | var = var.sum(dim=1).unsqueeze(dim=1) # N * 1 20 | 21 | ps = (x - avg) / torch.sqrt(var) # N * 1 22 | 23 | ps = nn.Sigmoid()(ps) # N * 1 24 | return ps * x + (1 - ps) * self.alpha * x 25 | 26 | 27 | def activation_layer(act_name): 28 | if isinstance(act_name, str): 29 | if act_name.lower() == 'sigmoid': 30 | act_layer = nn.Sigmoid() 31 | elif act_name.lower() == 'relu': 32 | act_layer = nn.ReLU(inplace=True) 33 | elif act_name.lower() == 'dice': 34 | act_layer = Dice() 35 | elif act_name.lower() == 'prelu': 36 | act_layer = nn.PReLU() 37 | elif act_name.lower() == "softmax": 38 | act_layer = nn.Softmax(dim=1) 39 | elif issubclass(act_name, nn.Module): 40 | act_layer = act_name() 41 | else: 42 | raise NotImplementedError 43 | return act_layer 44 | 
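# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of how the Dice activation and the activation_layer()
# factory above are expected to be used; the tensor shape below is an arbitrary
# assumption chosen only for the demo, not a project default.
if __name__ == "__main__":
    demo = torch.randn(4, 8)                   # (batch_size, num_neurons)
    dice = activation_layer("dice")            # returns the Dice module defined above
    relu = activation_layer("relu")
    print(dice(demo).shape, relu(demo).shape)  # both keep the input shape: torch.Size([4, 8])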
-------------------------------------------------------------------------------- /FRCTR/basic/features.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | -------------------------------------------------------------------------------- /FRCTR/basic/initializers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | 8 | 9 | class RandomNormal(object): 10 | 11 | def __init__(self, mean=0.0, std=1.0): 12 | self.mean = mean 13 | self.std = std 14 | 15 | def __call__(self, vocab_size, embed_dim): 16 | embed = torch.nn.Embedding(vocab_size, embed_dim) 17 | torch.nn.init.normal_(embed.weight, self.mean, self.std) 18 | return embed 19 | 20 | 21 | class RandomUniform(object): 22 | def __init__(self, minval=0.0, maxval=1.0): 23 | self.minval = minval 24 | self.maxval = maxval 25 | 26 | def __call__(self, vocab_size, embed_dim): 27 | embed = torch.nn.Embedding(vocab_size, embed_dim) 28 | torch.nn.init.uniform_(embed.weight, self.minval, self.maxval) 29 | return embed 30 | 31 | 32 | class XavierNormal(object): 33 | 34 | def __init__(self, gain=1.0): 35 | self.gain = gain 36 | 37 | def __call__(self, vocab_size, embed_dim): 38 | embed = torch.nn.Embedding(vocab_size, embed_dim) 39 | torch.nn.init.xavier_normal_(embed.weight, self.gain) 40 | return embed 41 | 42 | 43 | class XavierUniform(object): 44 | def __init__(self, gain=1.0): 45 | self.gain = gain 46 | 47 | def __call__(self, vocab_size, embed_dim): 48 | embed = torch.nn.Embedding(vocab_size, embed_dim) 49 | torch.nn.init.xavier_uniform_(embed.weight, self.gain) 50 | return embed 51 | 52 | 53 | class Pretrained(object): 54 | def __init__(self, embedding_weight, freeze=True): 55 | self.embedding_weight = torch.FloatTensor(embedding_weight) 56 | self.freeze = freeze 57 | 58 | def __call__(self, vocab_size, embed_dim): 59 | assert vocab_size == self.embedding_weight.shape[0] and embed_dim == self.embedding_weight.shape[1] 60 | embed = torch.nn.Embedding.from_pretrained(self.embedding_weight, freeze=self.freeze) 61 | return embed 62 | -------------------------------------------------------------------------------- /FRCTR/basic/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | from sklearn.metrics import log_loss, roc_auc_score 7 | 8 | 9 | -------------------------------------------------------------------------------- /FRCTR/common/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | import sys 7 | sys.path.append("../") 8 | from .attention_layer import * 9 | from .layers import * 10 | from .interaction_layer import * 11 | from .basic_layers import * -------------------------------------------------------------------------------- /FRCTR/common/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/common/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/common/__pycache__/attention_layer.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/common/__pycache__/attention_layer.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/common/__pycache__/basic_layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/common/__pycache__/basic_layers.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/common/__pycache__/interaction_layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/common/__pycache__/interaction_layer.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/common/__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/common/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/common/attention_layer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | class AttentionLayer(nn.Module): 12 | def __init__(self, dim=32): 13 | super().__init__() 14 | self.dim = dim 15 | self.q_layer = nn.Linear(dim, dim, bias=False) 16 | self.k_layer = nn.Linear(dim, dim, bias=False) 17 | self.v_layer = nn.Linear(dim, dim, bias=False) 18 | self.softmax = nn.Softmax(dim=-1) 19 | 20 | def forward(self, x): 21 | Q = self.q_layer(x) 22 | K = self.k_layer(x) 23 | V = self.v_layer(x) 24 | a = torch.sum(torch.mul(Q, K), -1) / torch.sqrt(torch.tensor(self.dim)) 25 | a = self.softmax(a) 26 | outputs = torch.sum(torch.mul(torch.unsqueeze(a, -1), V), dim=1) 27 | return outputs 28 | 29 | 30 | class FieldAttentionModule(nn.Module): 31 | def __init__(self, embed_dim): 32 | super(FieldAttentionModule, self).__init__() 33 | self.trans_Q = nn.Linear(embed_dim, embed_dim) 34 | self.trans_K = nn.Linear(embed_dim, embed_dim) 35 | self.trans_V = nn.Linear(embed_dim, embed_dim) 36 | 37 | def forward(self, x, scale=None, mask=None): 38 | Q = self.trans_Q(x) 39 | K = self.trans_K(x) 40 | V = self.trans_V(x) 41 | 42 | attention = torch.matmul(Q, K.permute(0, 2, 1)) 43 | if scale: 44 | attention = attention * scale 45 | if mask: 46 | attention = attention.masked_fill_(mask == 0, -1e9) 47 | attention = F.softmax(attention, dim=-1) 48 | context = torch.matmul(attention, V) 49 | 50 | return context 51 | 52 | 53 | class Attention(nn.Module): 54 | def __init__(self, method='dot', hidden_size=None): 55 | super(Attention, self).__init__() 56 | self.method = method 57 | if self.method != 'dot': 58 | self.hidden_size = hidden_size 59 | if self.method == 'general': 60 | self.W = nn.Linear(hidden_size, hidden_size) 61 | elif self.method == 'concat': 62 | self.W = nn.Linear(self.hidden_size * 2, hidden_size) 63 | self.v = nn.Parameter(torch.rand(1, hidden_size)) 64 | nn.init.xavier_normal_(self.v.data) 65 | 66 | def forward(self, 
query, key, value, mask=None, dropout=0): 67 | if self.method == 'general': 68 | scores = self.general(query, key) 69 | elif self.method == 'concat': 70 | scores = self.concat(query, key) 71 | else: 72 | scores = self.dot(query, key) 73 | 74 | if mask is not None: 75 | scores = scores.masked_fill(mask == 0, -1e9) 76 | p_attn = F.softmax(scores, dim=-1) 77 | if not dropout: 78 | p_attn = F.dropout(p_attn, dropout) 79 | 80 | return torch.matmul(p_attn, value), p_attn 81 | 82 | def dot(self, query, key): 83 | scores = torch.matmul(query, key.transpose(-2, -1)) 84 | return scores 85 | 86 | def general(self, query, key): 87 | scores = torch.matmul(self.W(query), key.transpose(-2, -1)) 88 | return scores 89 | 90 | def concat(self, query, key): 91 | scores = torch.cat((query.expand(-1, key.size(1), -1), key), dim=2) 92 | scores = self.W(scores) 93 | scores = F.tanh(scores) 94 | scores = torch.matmul(scores, self.v.t()).transpose(-2, -1) 95 | return scores 96 | 97 | 98 | class GeneralAttention(nn.Module): 99 | def __init__(self, embed_dim, conv_size=0): 100 | super(GeneralAttention, self).__init__() 101 | if conv_size == 0: 102 | conv_size = embed_dim 103 | # self.attention = torch.nn.Linear(embed_dim, embed_dim) 104 | self.attention = torch.nn.Linear(embed_dim, conv_size) 105 | self.projection = torch.nn.Linear(conv_size, 1) 106 | # self.projection = torch.nn.Linear(embed_dim, 1) 107 | 108 | def forward(self, key, dim=1): 109 | attn_scores = F.relu(self.attention(key)) 110 | attn_scores = F.softmax(self.projection(attn_scores), dim=dim) 111 | attn_output = torch.sum(attn_scores * key, dim=dim) # B,e 112 | return attn_output, attn_scores 113 | 114 | 115 | class ScaledDotProductAttention(nn.Module): 116 | def __init__(self, d_model, d_k, d_v, h, dropout=.1): 117 | super(ScaledDotProductAttention, self).__init__() 118 | self.fc_q = nn.Linear(d_model, h * d_k) 119 | self.fc_k = nn.Linear(d_model, h * d_k) 120 | self.fc_v = nn.Linear(d_model, h * d_v) 121 | self.fc_o = nn.Linear(h * d_v, d_model) 122 | self.dropout = nn.Dropout(dropout) 123 | 124 | self.d_model = d_model 125 | self.d_k = d_k 126 | self.d_v = d_v 127 | self.h = h 128 | 129 | self.init_weights() 130 | 131 | def init_weights(self): 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 135 | if m.bias is not None: 136 | nn.init.constant_(m.bias, 0) 137 | elif isinstance(m, nn.BatchNorm2d): 138 | nn.init.constant_(m.weight, 1) 139 | nn.init.constant_(m.bias, 0) 140 | elif isinstance(m, nn.Linear): 141 | nn.init.normal_(m.weight, std=0.001) 142 | if m.bias is not None: 143 | nn.init.constant_(m.bias, 0) 144 | 145 | def forward(self, queries, keys, values, attention_mask=None, attention_weights=None): 146 | 147 | b_s, nq = queries.shape[:2] 148 | nk = keys.shape[1] 149 | 150 | q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k) 151 | k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk) 152 | v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v) 153 | # scale = (key.size(-1) // num_heads) ** -0.5 154 | att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk) 155 | if attention_weights is not None: 156 | att = att * attention_weights 157 | if attention_mask is not None: 158 | att = att.masked_fill(attention_mask, -np.inf) 159 | att = torch.softmax(att, -1) 160 | att = self.dropout(att) 161 | 162 | out = torch.matmul(att, v).permute(0, 2, 1, 
3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v) 163 | out = self.fc_o(out) # (b_s, nq, d_model) 164 | return out 165 | 166 | class ScaleDotProductAttention(nn.Module): 167 | 168 | def forward(self, q, k, v, scale=None, attn_mask=None): 169 | attention = torch.bmm(q, k.transpose(1, 2)) 170 | if scale: 171 | attention = attention * scale 172 | if attn_mask: 173 | attention = attention.masked_fill(attn_mask, -np.inf) 174 | attention = torch.softmax(attention, dim=2) 175 | 176 | attention = torch.dropout(attention, p=0.0, train=self.training) 177 | context = torch.bmm(attention, v) 178 | return context, attention 179 | 180 | 181 | class MultiHeadAttention(nn.Module): 182 | def __init__(self, model_dim=20, dk=32, num_heads=16, out_dim=32, use_res = True): 183 | super(MultiHeadAttention, self).__init__() 184 | self.use_res = use_res 185 | self.dim_per_head = dk 186 | self.num_heads = num_heads 187 | 188 | self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads) 189 | self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads) 190 | self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads) 191 | 192 | self.outputw = torch.nn.Linear(self.dim_per_head * num_heads, out_dim) 193 | if self.use_res: 194 | # self.linear_residual = nn.Linear(model_dim, self.dim_per_head * num_heads) 195 | self.linear_residual = nn.Linear(model_dim, out_dim) 196 | # self.dot_product_attention = DotAttention() 197 | # self.linear_final = nn.Linear(model_dim, model_dim) 198 | # self.linear_residual = nn.Linear(model_dim, self.dim_per_head * num_heads) 199 | # self.layer_norm = nn.LayerNorm(model_dim) # LayerNorm 归一化。 200 | 201 | def _dot_product_attention(self, q, k, v, scale=None, attn_mask=None): 202 | attention = torch.bmm(q, k.transpose(1, 2)) 203 | if scale: 204 | attention = attention * scale 205 | if attn_mask: 206 | attention = attention.masked_fill(attn_mask, -np.inf) 207 | # score = softmax(QK^T / (d_k ** 0.5)) 208 | attention = torch.softmax(attention, dim=2) 209 | 210 | attention = torch.dropout(attention, p=0.0, train=self.training) 211 | # out = score * V 212 | context = torch.bmm(attention, v) 213 | return context, attention 214 | 215 | def forward(self, query, key, value, attn_mask=None): 216 | batch_size = key.size(0) 217 | 218 | key = self.linear_k(key) # K = UWk [B, 10, 256*16] 219 | value = self.linear_v(value) # Q = UWv [B, 10, 256*16] 220 | query = self.linear_q(query) # V = UWq [B, 10, 256*16] 221 | 222 | # [B, 10, 256*16] =》 [B*16, 10, 256] 223 | key = key.view(batch_size * self.num_heads, -1, self.dim_per_head) 224 | value = value.view(batch_size * self.num_heads, -1, self.dim_per_head) 225 | query = query.view(batch_size * self.num_heads, -1, self.dim_per_head) 226 | 227 | if attn_mask: 228 | attn_mask = attn_mask.unsqueeze(1).repeat(self.num_heads*batch_size, query.size(1), 1) 229 | 230 | scale = (key.size(-1) // self.num_heads) ** -0.5 231 | # QK^T/(dk**0.5) * V 232 | context, attention = self._dot_product_attention(query, key, value, scale, attn_mask) # [B*16, 10, 256] 233 | 234 | context = context.view(batch_size, -1, self.dim_per_head * self.num_heads) # [B, 10, 256*16] 235 | context = self.outputw(context) # B, F, out_dim 236 | 237 | if self.use_res: 238 | context += self.linear_residual(query) # B, F, out_dim 239 | 240 | return context, attention 241 | 242 | class MultiHeadAttention2(nn.Module): 243 | 244 | def __init__(self, query_dim, key_dim, num_units, num_heads): 245 | 246 | super().__init__() 247 | self.num_units = num_units 248 | 
self.num_heads = num_heads 249 | self.key_dim = key_dim 250 | 251 | self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) 252 | self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) 253 | self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) 254 | 255 | def forward(self, query, key, mask=None): 256 | querys = self.W_query(query) # [N, T_q, num_units] 257 | keys = self.W_key(key) # [N, T_k, num_units] 258 | values = self.W_value(key) 259 | 260 | split_size = self.num_units // self.num_heads 261 | querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] 262 | keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] 263 | values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] 264 | 265 | scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] 266 | scores = scores / (self.key_dim ** 0.5) 267 | 268 | if mask: 269 | mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1) 270 | scores = scores.masked_fill(mask, -np.inf) 271 | scores = F.softmax(scores, dim=3) 272 | 273 | out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] 274 | out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] 275 | 276 | return out,scores 277 | -------------------------------------------------------------------------------- /FRCTR/common/basic_layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | class Skip(nn.Module): 10 | def forward(self, x_emb): 11 | return x_emb, None 12 | 13 | class BasicFRCTR(nn.Module): 14 | def __init__(self, field_dims, embed_dim, FRN=None): 15 | super(BasicFRCTR, self).__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.frn = FRN 18 | if not FRN: 19 | self.frn = Skip() 20 | self.num_fields = len(field_dims) 21 | 22 | def forward(self, x): 23 | raise NotImplemented 24 | 25 | 26 | class FeaturesLinear(nn.Module): 27 | def __init__(self, field_dims, output_dim=1): 28 | super().__init__() 29 | self.fc = torch.nn.Embedding(sum(field_dims), output_dim) 30 | self.bias = torch.nn.Parameter(torch.zeros((output_dim,))) 31 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long) 32 | 33 | def forward(self, x): 34 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 35 | return torch.sum(self.fc(x), dim=1) + self.bias 36 | 37 | 38 | class FeaturesEmbedding(torch.nn.Module): 39 | def __init__(self, field_dims, embed_dim): 40 | super().__init__() 41 | self.embedding = torch.nn.Embedding(sum(field_dims), embed_dim) 42 | self.offsets = np.array((0, *np.cumsum(field_dims)[:-1]), dtype=np.long) 43 | torch.nn.init.xavier_uniform_(self.embedding.weight.data) 44 | 45 | def forward(self, x): 46 | x = x + x.new_tensor(self.offsets).unsqueeze(0) 47 | return self.embedding(x) 48 | -------------------------------------------------------------------------------- /FRCTR/common/interaction_layer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import itertools 10 | 11 | class FactorizationMachine(nn.Module): 12 | def __init__(self, 
reduce_sum=True): 13 | super().__init__() 14 | self.reduce_sum = reduce_sum 15 | 16 | def forward(self, x): 17 | square_of_sum = torch.sum(x, dim=1) ** 2 18 | sum_of_square = torch.sum(x ** 2, dim=1) 19 | ix = square_of_sum - sum_of_square 20 | if self.reduce_sum: 21 | ix = torch.sum(ix, dim=1, keepdim=True) 22 | return 0.5 * ix 23 | 24 | 25 | class MultiLayerPerceptron(nn.Module): 26 | def __init__(self, input_dim, embed_dims, dropout=0.5, output_layer=True): 27 | super().__init__() 28 | layers = list() 29 | for embed_dim in embed_dims: 30 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 31 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 32 | layers.append(torch.nn.ReLU()) 33 | layers.append(torch.nn.Dropout(p=dropout)) 34 | input_dim = embed_dim 35 | 36 | if output_layer: 37 | layers.append(torch.nn.Linear(input_dim, 1)) 38 | self.mlp = torch.nn.Sequential(*layers) 39 | self._init_weight_() 40 | 41 | def _init_weight_(self): 42 | for m in self.mlp: 43 | if isinstance(m, nn.Linear): 44 | nn.init.xavier_uniform_(m.weight) 45 | 46 | def forward(self, x): 47 | return self.mlp(x) 48 | 49 | class CrossNetwork(nn.Module): 50 | def __init__(self, input_dim, cn_layers): 51 | super().__init__() 52 | 53 | self.cn_layers = cn_layers 54 | 55 | self.w = torch.nn.ModuleList([ 56 | torch.nn.Linear(input_dim, 1, bias=False) for _ in range(cn_layers) 57 | ]) 58 | self.b = torch.nn.ParameterList([torch.nn.Parameter( 59 | torch.zeros((input_dim,))) for _ in range(cn_layers)]) 60 | 61 | def forward(self, x): 62 | x0 = x 63 | for i in range(self.cn_layers): 64 | xw = self.w[i](x) 65 | x = x0 * xw + self.b[i] + x 66 | return x 67 | 68 | 69 | class CrossNetworkV2(nn.Module): 70 | def __init__(self, input_dim, cn_layers): 71 | super().__init__() 72 | 73 | self.cn_layers = cn_layers 74 | 75 | self.w = torch.nn.ModuleList([ 76 | torch.nn.Linear(input_dim, input_dim, bias=False) for _ in range(cn_layers) 77 | ]) 78 | self.b = torch.nn.ParameterList([torch.nn.Parameter( 79 | torch.zeros((input_dim,))) for _ in range(cn_layers)]) 80 | 81 | def forward(self, x): 82 | x0 = x 83 | for i in range(self.cn_layers): 84 | xw = self.w[i](x) 85 | x = x0 * (xw + self.b[i]) + x 86 | return x 87 | 88 | 89 | class CompressedInteractionNetwork(nn.Module): 90 | def __init__(self, input_dim, cross_layer_sizes, split_half=True): 91 | super().__init__() 92 | self.num_layers = len(cross_layer_sizes) 93 | self.split_half = split_half 94 | self.conv_layers = torch.nn.ModuleList() 95 | prev_dim, fc_input_dim = input_dim, 0 96 | for i in range(self.num_layers): 97 | cross_layer_size = cross_layer_sizes[i] 98 | self.conv_layers.append(torch.nn.Conv1d(input_dim * prev_dim, cross_layer_size, 1, 99 | stride=1, dilation=1, bias=True)) 100 | if self.split_half and i != self.num_layers - 1: 101 | cross_layer_size //= 2 102 | prev_dim = cross_layer_size 103 | fc_input_dim += prev_dim 104 | self.fc = torch.nn.Linear(fc_input_dim, 1) 105 | 106 | def forward(self, x): 107 | xs = list() 108 | x0, h = x.unsqueeze(2), x 109 | for i in range(self.num_layers): 110 | x = x0 * h.unsqueeze(1) 111 | batch_size, f0_dim, fin_dim, embed_dim = x.shape 112 | x = x.view(batch_size, f0_dim * fin_dim, embed_dim) 113 | x = F.relu(self.conv_layers[i](x)) 114 | if self.split_half and i != self.num_layers - 1: 115 | x, h = torch.split(x, x.shape[1] // 2, dim=1) 116 | else: 117 | h = x 118 | xs.append(x) 119 | return self.fc(torch.sum(torch.cat(xs, dim=1), 2)) 120 | 121 | 122 | class InnerProductNetwork(nn.Module): 123 | def __init__(self, num_fields): 124 | 
super(InnerProductNetwork, self).__init__() 125 | self.row, self.col = list(), list() 126 | for i in range(num_fields - 1): 127 | for j in range(i + 1, num_fields): 128 | self.row.append(i), self.col.append(j) 129 | 130 | def forward(self, x): 131 | return torch.sum(x[:, self.row] * x[:, self.col], dim=2) 132 | 133 | 134 | class OuterProductNetwork(nn.Module): 135 | 136 | def __init__(self, num_fields, embed_dim, kernel_type='mat'): 137 | super().__init__() 138 | num_ix = num_fields * (num_fields - 1) // 2 139 | if kernel_type == 'mat': 140 | kernel_shape = embed_dim, num_ix, embed_dim 141 | elif kernel_type == 'vec': 142 | kernel_shape = num_ix, embed_dim 143 | elif kernel_type == 'num': 144 | kernel_shape = num_ix, 1 145 | else: 146 | raise ValueError('unknown kernel type: ' + kernel_type) 147 | self.kernel_type = kernel_type 148 | self.kernel = torch.nn.Parameter(torch.zeros(kernel_shape)) 149 | torch.nn.init.xavier_uniform_(self.kernel.data) 150 | 151 | self.row, self.col = list(), list() 152 | for i in range(num_fields - 1): 153 | for j in range(i + 1, num_fields): 154 | self.row.append(i), self.col.append(j) 155 | 156 | def forward(self, x): 157 | """ 158 | :param x: Float tensor of size ``(batch_size, num_fields, embed_dim)`` 159 | """ 160 | p, q = x[:, self.row], x[:, self.col] 161 | if self.kernel_type == 'mat': 162 | # p [b,1,num_ix,e] 163 | # kernel [e, num_ix, e] 164 | kp = torch.sum(p.unsqueeze(1) * self.kernel, dim=-1).permute(0, 2, 1) # b,num_ix,e 165 | return torch.sum(kp * q, -1) 166 | else: 167 | # p * q [B,ix,E] * [1,ix,E] => B,ix,E 168 | return torch.sum(p * q * self.kernel.unsqueeze(0), -1) 169 | 170 | 171 | class OuterProductNetwork2(nn.Module): 172 | """ 173 | Outer product with 174 | """ 175 | def __init__(self, num_fields): 176 | super().__init__() 177 | self.row, self.col = list(), list() 178 | for i in range(num_fields - 1): 179 | for j in range(i + 1, num_fields): 180 | self.row.append(i), self.col.append(j) 181 | 182 | def forward(self, x): 183 | p, q = x[:, self.row], x[:, self.col] 184 | # B,IX,E,1 B,IX,1,E 185 | p, q = p.unsqueeze(-1), q.unsqueeze(2) 186 | pq = torch.matmul(p, q) # B,IX,E,E 187 | pq = torch.sum(torch.sum(pq, dim=-1), dim=-1) # B,IX 188 | return pq 189 | 190 | class SenetLayer(nn.Module): 191 | def __init__(self, field_length, ratio=1): 192 | super(SenetLayer, self).__init__() 193 | self.temp_dim = max(1, field_length // ratio) 194 | self.excitation = nn.Sequential( 195 | nn.Linear(field_length, self.temp_dim), 196 | nn.ReLU(), 197 | nn.Linear(self.temp_dim, field_length), 198 | nn.ReLU() 199 | ) 200 | 201 | def forward(self, x_emb): 202 | Z_mean = torch.max(x_emb, dim=2, keepdim=True)[0].transpose(1, 2) 203 | # Z_mean = torch.mean(x_emb, dim=2, keepdim=True).transpose(1, 2) 204 | A_weight = self.excitation(Z_mean).transpose(1, 2) 205 | V_embed = torch.mul(A_weight, x_emb) 206 | return V_embed, A_weight 207 | 208 | class BilinearInteractionLayer(nn.Module): 209 | def __init__(self, filed_size, embedding_size, bilinear_type="interaction"): 210 | super(BilinearInteractionLayer, self).__init__() 211 | self.bilinear_type = bilinear_type 212 | self.bilinear = nn.ModuleList() 213 | 214 | if self.bilinear_type == "all": 215 | self.bilinear = nn.Linear( 216 | embedding_size, embedding_size, bias=False) 217 | 218 | elif self.bilinear_type == "each": 219 | for i in range(filed_size): 220 | self.bilinear.append( 221 | nn.Linear(embedding_size, embedding_size, bias=False)) 222 | 223 | elif self.bilinear_type == "interaction": 224 | for i, j in 
itertools.combinations(range(filed_size), 2): 225 | self.bilinear.append( 226 | nn.Linear(embedding_size, embedding_size, bias=False)) 227 | else: 228 | raise NotImplementedError 229 | 230 | def forward(self, inputs): 231 | if len(inputs.shape) != 3: 232 | raise ValueError( 233 | "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape))) 234 | inputs = torch.split(inputs, 1, dim=1) 235 | if self.bilinear_type == "all": 236 | p = [torch.mul(self.bilinear(v_i), v_j) 237 | for v_i, v_j in itertools.combinations(inputs, 2)] 238 | 239 | elif self.bilinear_type == "each": 240 | p = [torch.mul(self.bilinear[i](inputs[i]), inputs[j]) 241 | for i, j in itertools.combinations(range(len(inputs)), 2)] 242 | 243 | elif self.bilinear_type == "interaction": 244 | p = [torch.mul(bilinear(v[0]), v[1]) 245 | for v, bilinear in zip(itertools.combinations(inputs, 2), self.bilinear)] 246 | else: 247 | raise NotImplementedError 248 | return torch.cat(p, dim=1) 249 | 250 | 251 | class AttentionalFactorizationMachine(nn.Module): 252 | def __init__(self, embed_dim, attn_size, num_fields, dropouts, reduce=True): 253 | super().__init__() 254 | self.attention = torch.nn.Linear(embed_dim, attn_size) 255 | self.projection = torch.nn.Linear(attn_size, 1) 256 | self.fc = torch.nn.Linear(embed_dim, 1) 257 | self.dropouts = dropouts 258 | self.reduce = reduce 259 | self.row, self.col = list(), list() 260 | for i in range(num_fields - 1): 261 | for j in range(i + 1, num_fields): 262 | self.row.append(i), self.col.append(j) 263 | 264 | def forward(self, x): 265 | p, q = x[:, self.row], x[:, self.col] 266 | inner_product = p * q 267 | 268 | attn_scores = F.relu(self.attention(inner_product)) 269 | 270 | attn_scores = F.softmax(self.projection(attn_scores), dim=1) 271 | attn_scores = F.dropout(attn_scores, p=self.dropouts[0]) 272 | 273 | attn_output = torch.sum(attn_scores * inner_product, dim=1) 274 | attn_output = F.dropout(attn_output, p=self.dropouts[1]) 275 | if not self.reduce: 276 | return attn_output 277 | return self.fc(attn_output) -------------------------------------------------------------------------------- /FRCTR/common/layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ -------------------------------------------------------------------------------- /FRCTR/common/prediction_layer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | import torch 6 | from torch import nn 7 | 8 | 9 | 10 | class BasicLR(nn.Module): 11 | def __init__(self, input_dim, sigmoid=False): 12 | super(BasicLR, self).__init__() 13 | self.sigmoid = sigmoid 14 | self.lr = nn.Linear(input_dim, 1, bias=True) 15 | 16 | def forward(self, x): 17 | return x 18 | 19 | 20 | class BasicDNN(nn.Module): 21 | def __init__(self, input_dim, embed_dims, dropout=0.5, output_layer=True, sigmoid=False): 22 | super().__init__() 23 | layers = list() 24 | for embed_dim in embed_dims: 25 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 26 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 27 | layers.append(torch.nn.ReLU()) 28 | layers.append(torch.nn.Dropout(p=dropout)) 29 | input_dim = embed_dim 30 | 31 | if output_layer: 32 | layers.append(torch.nn.Linear(input_dim, 1)) 33 | self.mlp = torch.nn.Sequential(*layers) 34 | self._init_weight_() 35 | 36 | self.sigmoid = sigmoid 37 | 38 | def _init_weight_(self): 39 | for m in 
self.mlp: 40 | if isinstance(m, nn.Linear): 41 | nn.init.xavier_uniform_(m.weight) 42 | 43 | def forward(self, x): 44 | return self.mlp(x) -------------------------------------------------------------------------------- /FRCTR/common/refinement_layer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | -------------------------------------------------------------------------------- /FRCTR/data/Criteo/CriteoDataLoader.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import shutil 4 | import struct 5 | from collections import defaultdict 6 | from functools import lru_cache 7 | from pathlib import Path 8 | 9 | import lmdb 10 | import numpy as np 11 | import torch.utils.data 12 | from tqdm import tqdm 13 | 14 | 15 | class CriteoDataset(torch.utils.data.Dataset): 16 | """ 17 | Criteo Display Advertising Challenge Dataset 18 | 19 | Data prepration: 20 | * Remove the infrequent features (appearing in less than threshold instances) and treat them as a single feature 21 | * Discretize numerical values by log2 transformation which is proposed by the winner of Criteo Competition 22 | 23 | :param dataset_path: criteo train.txt path. 24 | :param cache_path: lmdb cache path. 25 | :param rebuild_cache: If True, lmdb cache is refreshed. 26 | :param min_threshold: infrequent feature threshold. 27 | 28 | Reference: 29 | https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset 30 | https://www.csie.ntu.edu.tw/~r01922136/kaggle-2014-criteo.pdf 31 | """ 32 | 33 | def __init__(self, dataset_path=None, cache_path='.criteo', rebuild_cache=False, min_threshold=10): 34 | self.NUM_FEATS = 39 35 | self.NUM_INT_FEATS = 13 36 | self.min_threshold = min_threshold 37 | self.prefix = "../data/criteo/" 38 | if rebuild_cache or not Path(cache_path).exists(): 39 | shutil.rmtree(cache_path, ignore_errors=True) 40 | if dataset_path is None: 41 | raise ValueError('create cache: failed: dataset_path is None') 42 | self.__build_cache(dataset_path, cache_path) 43 | self.env = lmdb.open(cache_path, create=False, lock=False, readonly=True) 44 | with self.env.begin(write=False) as txn: 45 | self.length = txn.stat()['entries'] - 1 46 | self.field_dims = np.frombuffer(txn.get(b'field_dims'), dtype=np.uint32) 47 | 48 | def __getitem__(self, index): 49 | with self.env.begin(write=False) as txn: 50 | np_array = np.frombuffer( 51 | txn.get(struct.pack('>I', index)), dtype=np.uint32).astype(dtype=np.long) 52 | return np_array[1:], np_array[0] 53 | 54 | def __len__(self): 55 | # Must be implemented 56 | return self.length 57 | 58 | def __build_cache(self, path, cache_path): 59 | temp_path = self.prefix + "train.txt" 60 | 61 | feat_mapper, defaults = self.__get_feat_mapper(temp_path) 62 | 63 | with lmdb.open(cache_path, map_size=int(1e11)) as env: 64 | field_dims = np.zeros(self.NUM_FEATS, dtype=np.uint32) 65 | for i, fm in feat_mapper.items(): 66 | field_dims[i - 1] = len(fm) + 1 67 | 68 | # save field_dims 69 | with env.begin(write=True) as txn: 70 | txn.put(b'field_dims', field_dims.tobytes()) 71 | 72 | for buffer in self.__yield_buffer(path, feat_mapper, defaults): 73 | with env.begin(write=True) as txn: 74 | for key, value in buffer: 75 | txn.put(key, value) 76 | 77 | def __read_train_all_feats(self): 78 | return pickle.load(open(self.prefix + "train_all_feat811.pkl", "rb")) 79 | 80 | def __get_feat_mapper(self, path): 81 | 82 | feat_cnts 
= defaultdict(lambda: defaultdict(int)) 83 | with open(path) as f: 84 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 85 | pbar.set_description('Create criteo dataset cache: counting features') 86 | for line in pbar: 87 | values = line.rstrip('\n').split('\t') 88 | if len(values) != self.NUM_FEATS + 1: 89 | continue 90 | 91 | for i in range(1, self.NUM_INT_FEATS + 1): 92 | feat_cnts[i][convert_numeric_feature(values[i])] += 1 93 | 94 | for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1): 95 | feat_cnts[i][values[i]] += 1 96 | 97 | feat_mapper = {i: {feat for feat, c in cnt.items() if c >= self.min_threshold} for i, cnt in feat_cnts.items()} 98 | feat_mapper = {i: {feat: idx for idx, feat in enumerate(cnt)} for i, cnt in feat_mapper.items()} 99 | defaults = {i: len(cnt) for i, cnt in feat_mapper.items()} 100 | 101 | f = open(self.prefix + "train_all_feat811.pkl", "wb") 102 | pickle.dump((feat_mapper, defaults), f) 103 | 104 | return feat_mapper, defaults 105 | 106 | def __yield_buffer(self, path, feat_mapper, defaults, buffer_size=int(1e5)): 107 | item_idx = 0 108 | buffer = list() 109 | with open(path) as f: 110 | pbar = tqdm(f, mininterval=1, smoothing=0.1) 111 | pbar.set_description('Create criteo dataset cache: setup lmdb') 112 | for line in pbar: 113 | values = line.rstrip('\n').split('\t') 114 | if len(values) != self.NUM_FEATS + 1: 115 | continue 116 | np_array = np.zeros(self.NUM_FEATS + 1, dtype=np.uint32) 117 | np_array[0] = int(values[0]) 118 | for i in range(1, self.NUM_INT_FEATS + 1): 119 | np_array[i] = feat_mapper[i].get(convert_numeric_feature(values[i]), defaults[i]) 120 | 121 | for i in range(self.NUM_INT_FEATS + 1, self.NUM_FEATS + 1): 122 | np_array[i] = feat_mapper[i].get(values[i], defaults[i]) 123 | buffer.append((struct.pack('>I', item_idx), np_array.tobytes())) 124 | item_idx += 1 125 | if item_idx % buffer_size == 0: 126 | yield buffer 127 | buffer.clear() 128 | yield buffer 129 | 130 | 131 | @lru_cache(maxsize=None) 132 | def convert_numeric_feature(val: str): 133 | if val == '': 134 | return 'NULL' 135 | v = int(val) 136 | if v > 2: 137 | return str(int(math.log(v) ** 2)) 138 | else: 139 | return str(v) 140 | 141 | 142 | def get_criteo_811(train_path="train.txt", batch_size=2048): 143 | # the test_path maybe null, if it is, we need to split the train dataset 144 | print("Start loading criteo data....") 145 | prefix = "../data/criteo/" 146 | train_path = prefix + train_path 147 | dataset = CriteoDataset(dataset_path=train_path, cache_path=prefix + ".criteoall") 148 | all_length = len(dataset) 149 | print(all_length) 150 | 151 | # 8:1:1 152 | test_size = int(0.1 * all_length) 153 | train_size = all_length - test_size * 2 154 | 155 | train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size * 2], 156 | generator=torch.Generator().manual_seed(2022)) 157 | test_dataset, valid_dataset = torch.utils.data.random_split(test_dataset, [test_size, test_size], 158 | generator=torch.Generator().manual_seed(2022)) 159 | print("train_dataset length:", len(train_dataset)) 160 | print("valid_dataset length:", len(valid_dataset)) 161 | print("test_dataset length:", len(test_dataset)) 162 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 163 | valid_Loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 164 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 165 | 166 | 
field_dims = dataset.field_dims 167 | return field_dims, train_loader, valid_Loader, test_loader 168 | 169 | 170 | """ 171 | the full data: train.txt, which contains over 45,000,000 simple, 13 numerical features and 26 categorical feature 172 | [ 49 101 126 45 223 118 84 76 95 9 173 | 30 40 75 1458 555 193949 138801 306 19 11970 174 | 634 4 42646 5178 192773 3175 27 11422 181075 11 175 | 4654 2032 5 189657 18 16 59697 86 45571] 176 | """ 177 | -------------------------------------------------------------------------------- /FRCTR/data/Criteo/__pycache__/CriteoDataLoader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/data/Criteo/__pycache__/CriteoDataLoader.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/data/Frappe/FrappeDataLoader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | ''' 4 | @Author:wangfy 5 | @project:DL_recommend 6 | @Time:2020/4/29 9:42 上午 7 | ''' 8 | 9 | import os 10 | import pickle 11 | 12 | import pandas as pd 13 | import torch 14 | import tqdm 15 | 16 | 17 | class LoadData811(): 18 | def __init__(self, path="./Data/", dataset="frappe", loss_type="square_loss"): 19 | self.dataset = dataset 20 | self.loss_type = loss_type 21 | self.path = path + dataset + "/" 22 | self.trainfile = self.path + dataset + ".train.libfm" 23 | self.testfile = self.path + dataset + ".test.libfm" 24 | self.validationfile = self.path + dataset + ".validation.libfm" 25 | self.features_M = {} 26 | self.construct_df() 27 | 28 | def construct_df(self): 29 | self.data_train = pd.read_table(self.trainfile, sep=" ", header=None, engine='python') 30 | self.data_test = pd.read_table(self.testfile, sep=" ", header=None, engine="python") 31 | self.data_valid = pd.read_table(self.validationfile, sep=" ", header=None, engine="python") 32 | 33 | for i in self.data_test.columns[1:]: 34 | self.data_test[i] = self.data_test[i].apply(lambda x: int(x.split(":")[0])) 35 | self.data_train[i] = self.data_train[i].apply(lambda x: int(x.split(":")[0])) 36 | self.data_valid[i] = self.data_valid[i].apply(lambda x: int(x.split(":")[0])) 37 | 38 | self.all_data = pd.concat([self.data_train, self.data_test, self.data_valid]) 39 | self.field_dims = [] 40 | 41 | for i in self.all_data.columns[1:]: 42 | maps = {val: k for k, val in enumerate(set(self.all_data[i]))} 43 | # self.data_test[i] = self.data_test[i].map(maps) 44 | # self.data_train[i] = self.data_train[i].map(maps) 45 | # self.data_valid[i] = self.data_valid[i].map(maps) 46 | self.all_data[i] = self.all_data[i].map(maps) 47 | self.features_M[i] = maps 48 | self.field_dims.append(len(set(self.all_data[i]))) 49 | 50 | self.all_data[0] = self.all_data[0].apply(lambda x: max(x, 0)) 51 | # self.data_test[0] = self.data_test[0].apply(lambda x: max(x, 0)) 52 | # self.data_train[0] = self.data_train[0].apply(lambda x: max(x, 0)) 53 | # self.data_valid[0] = self.data_valid[0].apply(lambda x: max(x, 0)) 54 | 55 | 56 | class LoadData(): 57 | def __init__(self, path="./Data/", dataset="frappe", loss_type="square_loss"): 58 | self.dataset = dataset 59 | self.loss_type = loss_type 60 | self.path = path + dataset + "/" 61 | self.trainfile = self.path + dataset + ".train.libfm" 62 | self.testfile = self.path + dataset + ".test.libfm" 63 | self.validationfile = self.path + dataset + 
".validation.libfm" 64 | self.features_M = {} 65 | self.construct_df() 66 | 67 | def construct_df(self): 68 | self.data_train = pd.read_table(self.trainfile, sep=" ", header=None, engine='python') 69 | self.data_test = pd.read_table(self.testfile, sep=" ", header=None, engine="python") 70 | self.data_valid = pd.read_table(self.validationfile, sep=" ", header=None, engine="python") 71 | 72 | for i in self.data_test.columns[1:]: 73 | self.data_test[i] = self.data_test[i].apply(lambda x: int(x.split(":")[0])) 74 | self.data_train[i] = self.data_train[i].apply(lambda x: int(x.split(":")[0])) 75 | self.data_valid[i] = self.data_valid[i].apply(lambda x: int(x.split(":")[0])) 76 | 77 | self.all_data = pd.concat([self.data_train, self.data_test, self.data_valid]) 78 | self.field_dims = [] 79 | 80 | for i in self.all_data.columns[1:]: 81 | maps = {val: k for k, val in enumerate(set(self.all_data[i]))} 82 | self.data_test[i] = self.data_test[i].map(maps) 83 | self.data_train[i] = self.data_train[i].map(maps) 84 | self.data_valid[i] = self.data_valid[i].map(maps) 85 | self.features_M[i] = maps 86 | self.field_dims.append(len(set(self.all_data[i]))) 87 | self.data_test[0] = self.data_test[0].apply(lambda x: max(x, 0)) 88 | self.data_train[0] = self.data_train[0].apply(lambda x: max(x, 0)) 89 | self.data_valid[0] = self.data_valid[0].apply(lambda x: max(x, 0)) 90 | 91 | 92 | class LoadDataMSE(): 93 | def __init__(self, path="./Data/", dataset="frappe", loss_type="square_loss"): 94 | self.dataset = dataset 95 | self.loss_type = loss_type 96 | self.path = path + dataset + "/" 97 | self.trainfile = self.path + dataset + ".train.libfm" 98 | self.testfile = self.path + dataset + ".test.libfm" 99 | self.validationfile = self.path + dataset + ".valid.libfm" 100 | self.features_M = {} 101 | self.construct_df() 102 | 103 | # self.Train_data, self.Validation_data, self.Test_data = self.construct_data( loss_type ) 104 | 105 | def construct_df(self): 106 | self.data_train = pd.read_table(self.trainfile, sep=" ", header=None, engine='python') 107 | self.data_test = pd.read_table(self.testfile, sep=" ", header=None, engine="python") 108 | self.data_valid = pd.read_table(self.validationfile, sep=" ", header=None, engine="python") 109 | # 第一列是标签,y 110 | 111 | for i in self.data_test.columns[1:]: 112 | self.data_test[i] = self.data_test[i].apply(lambda x: int(x.split(":")[0])) 113 | self.data_train[i] = self.data_train[i].apply(lambda x: int(x.split(":")[0])) 114 | self.data_valid[i] = self.data_valid[i].apply(lambda x: int(x.split(":")[0])) 115 | 116 | self.all_data = pd.concat([self.data_train, self.data_test, self.data_valid]) 117 | self.field_dims = [] 118 | 119 | for i in self.all_data.columns[1:]: 120 | # if self.dataset != "frappe": 121 | # maps = {} 122 | maps = {val: k for k, val in enumerate(set(self.all_data[i]))} 123 | self.data_test[i] = self.data_test[i].map(maps) 124 | self.data_train[i] = self.data_train[i].map(maps) 125 | self.data_valid[i] = self.data_valid[i].map(maps) 126 | self.features_M[i] = maps 127 | self.field_dims.append(len(set(self.all_data[i]))) 128 | # -1 改成 0 129 | # self.data_test[0] = self.data_test[0].apply(lambda x: max(x, 0)) 130 | # self.data_train[0] = self.data_train[0].apply(lambda x: max(x, 0)) 131 | # self.data_valid[0] = self.data_valid[0].apply(lambda x: max(x, 0)) 132 | 133 | 134 | class RecData(): 135 | def __init__(self, all_data): 136 | self.data_df = all_data 137 | 138 | def __len__(self): 139 | return len(self.data_df) 140 | 141 | def __getitem__(self, idx): 142 | x = 
self.data_df.iloc[idx].values[1:] 143 | y1 = self.data_df.iloc[idx].values[0] 144 | return x, y1 145 | 146 | 147 | def getfrappe_loader811(path="../data/", dataset="frappe", num_ng=4, batch_size=256): 148 | print("start load frappe dataset") 149 | AllDataF = LoadData811(path=path, dataset=dataset) 150 | all_dataset = RecData(AllDataF.all_data) 151 | 152 | train_size = int(0.9 * len(all_dataset)) 153 | test_size = len(all_dataset) - train_size 154 | # 8:1:1 155 | train_dataset, test_dataset = torch.utils.data.random_split(all_dataset, [train_size, test_size]) 156 | train_dataset, valid_dataset = torch.utils.data.random_split(train_dataset, [train_size - test_size, test_size]) 157 | 158 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 159 | shuffle=True, num_workers=4, drop_last=True) 160 | valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, 161 | shuffle=True, num_workers=4) 162 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, 163 | shuffle=False, num_workers=4) 164 | print(len(train_loader)) 165 | print(len(valid_loader)) 166 | print(len(test_loader)) 167 | return AllDataF.field_dims, train_loader, valid_loader, test_loader 168 | 169 | 170 | def getdataloader_frappe(path="../data/", dataset="frappe", num_ng=4, batch_size=256): 171 | print(os.getcwd()) 172 | print("start load frappe dataset") 173 | DataF = LoadData(path=path, dataset=dataset) 174 | # 7:2:1 175 | datatest = RecData(DataF.data_test) 176 | datatrain = RecData(DataF.data_train) 177 | datavalid = RecData(DataF.data_valid) 178 | print("datatest", len(datatest)) 179 | print("datatrain", len(datatrain)) 180 | print("datavalid", len(datavalid)) 181 | trainLoader = torch.utils.data.DataLoader(datatrain, batch_size=batch_size, shuffle=True, num_workers=8, 182 | pin_memory=True, drop_last=True) 183 | validLoader = torch.utils.data.DataLoader(datavalid, batch_size=batch_size, shuffle=False, num_workers=4, 184 | pin_memory=True) 185 | testLoader = torch.utils.data.DataLoader(datatest, batch_size=batch_size, shuffle=False, num_workers=4, 186 | pin_memory=True) 187 | return DataF.field_dims, trainLoader, validLoader, testLoader 188 | 189 | 190 | if __name__ == '__main__': 191 | 192 | field_dims, trainLoader, validLoader, testLoader = getfrappe_loader811(path="../", batch_size=256) 193 | for _ in tqdm.tqdm(trainLoader): 194 | pass 195 | it = iter(trainLoader) 196 | print(next(it)[0]) 197 | print(field_dims) 198 | -------------------------------------------------------------------------------- /FRCTR/data/Frappe/__pycache__/FrappeDataLoader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/data/Frappe/__pycache__/FrappeDataLoader.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/data/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | from .Frappe.FrappeDataLoader import getdataloader_frappe, getfrappe_loader811 7 | from .Criteo.CriteoDataLoader import get_criteo_811 8 | -------------------------------------------------------------------------------- /FRCTR/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/codectr/RefineCTR/edd3965120f54cf4e799b179d4383fda2d7ef65e/FRCTR/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /FRCTR/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | import sys 6 | sys.path.append("../") 7 | 8 | from .fm import FMFrn 9 | from .deepfm import DeepFMFrn, DeepFMFrnP 10 | from .dcn import CNFrn, DCNFrn, DCNFrnP 11 | from .dcnv2 import CN2Frn, DCNV2Frn, DCNV2FrnP 12 | from .afnp import AFNFrn, AFNPlusFrn, AFNPlusFrnP 13 | from .xdeepfm import CINFrn, xDeepFMFrn, xDeepFMFrnP 14 | 15 | from .fibinet import FiBiNetFrn 16 | from .fwfm import FwFMFrn 17 | from .fnn import FNNFrn 18 | from .nfm import NFMFrn 19 | from .fint import FINTFrn -------------------------------------------------------------------------------- /FRCTR/model_zoo/afm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | 8 | from FRCTR.common import FeaturesLinear, FeaturesEmbedding, AttentionalFactorizationMachine, BasicFRCTR 9 | 10 | 11 | class AFM(torch.nn.Module): 12 | def __init__(self, field_dims, embed_dim, attn_size, dropouts=(0.5, 0.5)): 13 | super().__init__() 14 | self.num_fields = len(field_dims) 15 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 16 | self.linear = FeaturesLinear(field_dims) 17 | self.afm = AttentionalFactorizationMachine(embed_dim, attn_size, 18 | self.num_fields, dropouts=dropouts, 19 | reduce=True) 20 | 21 | def forward(self, x): 22 | x_emb = self.embedding(x) 23 | cross_term = self.afm(x_emb) 24 | pred_y = self.linear(x) + cross_term 25 | return pred_y 26 | 27 | 28 | class AFMFrn(BasicFRCTR): 29 | def __init__(self, field_dims, embed_dim, FRN=None, attn_size=16, dropouts=(0.5, 0.5)): 30 | super().__init__(field_dims, embed_dim, FRN) 31 | self.num_fields = len(field_dims) 32 | self.linear = FeaturesLinear(field_dims) 33 | self.afm = AttentionalFactorizationMachine(embed_dim, attn_size, 34 | self.num_fields, dropouts=dropouts, 35 | reduce=True) 36 | 37 | def forward(self, x): 38 | x_emb = self.embedding(x) 39 | x_emb, weight = self.frn(x_emb) 40 | cross_term = self.afm(x_emb) 41 | pred_y = self.linear(x) + cross_term 42 | return pred_y 43 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/afnp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import math 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | from FRCTR.common import FeaturesEmbedding, MultiLayerPerceptron, BasicFRCTR, FeaturesLinear 13 | 14 | 15 | class LNN(torch.nn.Module): 16 | """ 17 | A pytorch implementation of LNN layer 18 | Input shape 19 | - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``. 20 | Output shape 21 | - 2D tensor with shape:``(batch_size,LNN_dim*embedding_size)``. 22 | Arguments 23 | - **in_features** : Embedding of feature. 24 | - **num_fields**: int.The field size of feature. 25 | - **LNN_dim**: int.The number of Logarithmic neuron. 26 | - **bias**: bool.Whether or not use bias in LNN. 
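        Computation (as implemented in ``forward`` below), per logarithmic neuron j and
        embedding dimension d:
            y_{j,d} = ReLU( expm1( sum_i W_{j,i} * log1p(|x_{i,d}| + 1e-7) + b_{j,d} ) )
        i.e. multiplicative cross features of arbitrary order are learned as weighted
        sums in log space.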
27 | """ 28 | 29 | def __init__(self, num_fields, embed_dim, LNN_dim, bias=False): 30 | super(LNN, self).__init__() 31 | self.num_fields = num_fields 32 | self.embed_dim = embed_dim 33 | self.LNN_dim = LNN_dim 34 | self.lnn_output_dim = LNN_dim * embed_dim 35 | 36 | self.weight = torch.nn.Parameter(torch.Tensor(LNN_dim, num_fields)) 37 | if bias: 38 | self.bias = torch.nn.Parameter(torch.Tensor(LNN_dim, embed_dim)) 39 | else: 40 | self.register_parameter('bias', None) 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | stdv = 1. / math.sqrt(self.weight.size(1)) 45 | self.weight.data.uniform_(-stdv, stdv) 46 | if self.bias is not None: 47 | self.bias.data.uniform_(-stdv, stdv) 48 | 49 | def forward(self, x): 50 | """ 51 | param x: Long tensor of size ``(batch_size, num_fields, embedding_size)`` 52 | """ 53 | # Computes the element-wise absolute value of the given input tensor. 54 | embed_x_abs = torch.abs(x) 55 | embed_x_afn = torch.add(embed_x_abs, 1e-7) 56 | # Logarithmic Transformation 57 | # torch.log1p 58 | embed_x_log = torch.log1p(embed_x_afn) 59 | lnn_out = torch.matmul(self.weight, embed_x_log) 60 | if self.bias is not None: 61 | lnn_out += self.bias 62 | 63 | # torch.expm1 64 | lnn_exp = torch.expm1(lnn_out) 65 | output = F.relu(lnn_exp).contiguous().view(-1, self.lnn_output_dim) 66 | return output 67 | 68 | 69 | class AFN(torch.nn.Module): 70 | def __init__(self, field_dims, embed_dim, LNN_dim=10, mlp_dims=(400, 400, 400), dropouts=(0.5, 0.5)): 71 | super().__init__() 72 | self.linear = FeaturesLinear(field_dims) 73 | self.num_fields = len(field_dims) 74 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 75 | self.LNN_dim = LNN_dim 76 | self.LNN_output_dim = self.LNN_dim * embed_dim 77 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 78 | 79 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0]) 80 | 81 | def forward(self, x): 82 | 83 | x_emb = self.embedding(x) 84 | 85 | lnn_out = self.LNN(x_emb) 86 | 87 | pred_y = self.mlp(lnn_out) + self.linear(x) 88 | return pred_y 89 | 90 | 91 | class AFNPlus(torch.nn.Module): 92 | def __init__(self, field_dims, embed_dim, LNN_dim=10, mlp_dims=(400, 400, 400), 93 | mlp_dims2=(400, 400, 400), dropouts=(0.5, 0.5)): 94 | super().__init__() 95 | self.num_fields = len(field_dims) 96 | self.linear = FeaturesLinear(field_dims) # Linear 97 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) # Embedding 98 | 99 | self.LNN_dim = LNN_dim 100 | self.LNN_output_dim = self.LNN_dim * embed_dim 101 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 102 | 103 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0]) 104 | 105 | self.embed_output_dim = len(field_dims) * embed_dim 106 | self.mlp2 = MultiLayerPerceptron(self.embed_output_dim, mlp_dims2, dropouts[1]) 107 | 108 | self.lr = torch.nn.Linear(2, 1, bias=True) 109 | 110 | def forward(self, x): 111 | """ 112 | param x: Long tensor of size ``(batch_size, num_fields)`` 113 | """ 114 | x_emb = self.embedding(x) 115 | 116 | lnn_out = self.LNN(x_emb) 117 | x_dnn = self.mlp2(x_emb.view(-1, self.embed_output_dim)) 118 | x_lnn = self.mlp(lnn_out) 119 | pred_y = self.linear(x) + x_lnn + x_dnn 120 | return pred_y 121 | 122 | 123 | class AFNFrn(BasicFRCTR): 124 | def __init__(self, field_dims, embed_dim, FRN=None, LNN_dim=16, mlp_dims=(400, 400, 400), dropouts=(0.5, 0.5)): 125 | super().__init__(field_dims, embed_dim, FRN) 126 | self.linear = FeaturesLinear(field_dims) 127 | self.num_fields = len(field_dims) 128 | self.LNN_dim 
= LNN_dim 129 | self.LNN_output_dim = self.LNN_dim * embed_dim 130 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 131 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0], output_layer=True) 132 | 133 | def forward(self, x): 134 | """ 135 | param x: Long tensor of size ``(batch_size, num_fields)`` 136 | """ 137 | x_emb = self.embedding(x) 138 | x_emb, _ = self.frn(x_emb) 139 | 140 | lnn_out = self.LNN(x_emb) 141 | 142 | pred_y = self.mlp(lnn_out) + self.linear(x) 143 | return pred_y 144 | 145 | 146 | class AFNPlusFrn(BasicFRCTR): 147 | def __init__(self, field_dims, embed_dim, FRN=None, LNN_dim=10, mlp_dims=(400, 400, 400), 148 | mlp_dims2=(400, 400, 400), dropouts=(0.5, 0.5)): 149 | super().__init__(field_dims, embed_dim, FRN) 150 | self.linear = FeaturesLinear(field_dims) 151 | self.num_fields = len(field_dims) 152 | 153 | self.LNN_dim = LNN_dim 154 | self.LNN_output_dim = self.LNN_dim * embed_dim 155 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 156 | 157 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0]) 158 | 159 | self.embed_output_dim = len(field_dims) * embed_dim 160 | self.mlp2 = MultiLayerPerceptron(self.embed_output_dim, mlp_dims2, dropouts[1]) 161 | 162 | self.lr = torch.nn.Linear(2, 1, bias=True) 163 | 164 | def forward(self, x): 165 | """ 166 | param x: Long tensor of size ``(batch_size, num_fields)`` 167 | """ 168 | x_emb = self.embedding(x) 169 | x_emb, _ = self.frn1(x_emb) 170 | 171 | lnn_out = self.LNN(x_emb) 172 | x_dnn = self.mlp2(x_emb.view(-1, self.embed_output_dim)) 173 | x_lnn = self.mlp(lnn_out) 174 | pred_y = self.linear(x) + x_lnn + x_dnn 175 | return pred_y 176 | 177 | 178 | class AFNPlusFrnP(nn.Module): 179 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, LNN_dim=10, mlp_dims=(400, 400, 400), 180 | mlp_dims2=(400, 400, 400), dropouts=(0.5, 0.5)): 181 | super().__init__() 182 | self.num_fields = len(field_dims) 183 | self.linear = FeaturesLinear(field_dims) # Linear 184 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) # Embedding 185 | 186 | if not FRN1 or not FRN2: 187 | raise ValueError("Feature Refinement Network is None") 188 | self.frn1 = FRN1 189 | self.frn2 = FRN2 190 | 191 | self.LNN_dim = LNN_dim 192 | self.LNN_output_dim = self.LNN_dim * embed_dim 193 | self.LNN = LNN(self.num_fields, embed_dim, LNN_dim) 194 | 195 | self.mlp = MultiLayerPerceptron(self.LNN_output_dim, mlp_dims, dropouts[0]) 196 | 197 | self.embed_output_dim = len(field_dims) * embed_dim 198 | self.mlp2 = MultiLayerPerceptron(self.embed_output_dim, mlp_dims2, dropouts[1]) 199 | 200 | self.lr = torch.nn.Linear(2, 1, bias=True) 201 | 202 | def forward(self, x): 203 | x_emb = self.embedding(x) 204 | x_emb1, weight1 = self.frn1(x_emb) 205 | x_emb2, weight2 = self.frn2(x_emb) 206 | 207 | lnn_out = self.LNN(x_emb1) 208 | x_lnn = self.mlp(lnn_out) 209 | 210 | x_dnn = self.mlp2(x_emb2.reshape(-1, self.embed_output_dim)) 211 | pred_y = self.linear(x) + x_lnn + x_dnn 212 | # pred_y = self.lr(torch.cat([x_lnn, x_dnn], dim=1)) 213 | return pred_y 214 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/autoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from FRCTR.common import FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 10 | 11 | 12 | class AutoIntPlus(torch.nn.Module): 13 | 
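    """
    AutoInt+ : stacked multi-head self-attention over the field embeddings (AutoInt),
    combined with a parallel MLP on the flattened embeddings; the attention output,
    the MLP output and a linear term are summed in forward() for the final prediction.
    """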
def __init__(self, field_dims, embed_dim, atten_embed_dim=64, num_heads=2, 14 | num_layers=3, mlp_dims=(400, 400, 400), dropouts=(0.5, 0.5), has_residual=True): 15 | super().__init__() 16 | self.num_fields = len(field_dims) 17 | self.linear = FeaturesLinear(field_dims) 18 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 19 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 20 | self.embed_output_dim = len(field_dims) * embed_dim 21 | self.atten_output_dim = len(field_dims) * atten_embed_dim 22 | self.has_residual = has_residual 23 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropouts[1]) 24 | 25 | self.self_attns = torch.nn.ModuleList([ 26 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 27 | ]) 28 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 29 | if self.has_residual: 30 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 31 | 32 | def forward(self, x): 33 | x_emb = self.embedding(x) 34 | atten_x = self.atten_embedding(x_emb) 35 | 36 | cross_term = atten_x.transpose(0, 1) 37 | 38 | for self_attn in self.self_attns: 39 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 40 | 41 | cross_term = cross_term.transpose(0, 1) 42 | if self.has_residual: 43 | V_res = self.V_res_embedding(x_emb) 44 | cross_term += V_res 45 | 46 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 47 | 48 | pred_y = self.linear(x) + self.attn_fc(cross_term) + self.mlp(x_emb.view(-1, self.embed_output_dim)) 49 | return pred_y 50 | 51 | class AutoIntPlusFrn(torch.nn.Module): 52 | def __init__(self, field_dims, embed_dim, FRN=None, atten_embed_dim=64, num_heads=2, 53 | num_layers=3, mlp_dims=(400, 400, 400), dropouts=(0.5, 0.5), has_residual=True): 54 | super().__init__() 55 | self.num_fields = len(field_dims) 56 | self.linear = FeaturesLinear(field_dims) 57 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 58 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 59 | self.embed_output_dim = len(field_dims) * embed_dim 60 | self.atten_output_dim = len(field_dims) * atten_embed_dim 61 | self.has_residual = has_residual 62 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropouts[1]) 63 | 64 | self.self_attns = torch.nn.ModuleList([ 65 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 66 | ]) 67 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 68 | if self.has_residual: 69 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 70 | self.frn = FRN 71 | def forward(self, x): 72 | x_emb = self.embedding(x) 73 | x_emb, _ = self.frn(x_emb) 74 | atten_x = self.atten_embedding(x_emb) 75 | 76 | cross_term = atten_x.transpose(0, 1) 77 | 78 | for self_attn in self.self_attns: 79 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 80 | 81 | cross_term = cross_term.transpose(0, 1) 82 | if self.has_residual: 83 | V_res = self.V_res_embedding(x_emb) 84 | cross_term += V_res 85 | 86 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 87 | 88 | pred_y = self.linear(x) + self.attn_fc(cross_term) + self.mlp(x_emb.view(-1, self.embed_output_dim)) 89 | return pred_y 90 | 91 | class AutoIntPlusFrnP(torch.nn.Module): 92 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, atten_embed_dim=64, num_heads=2, 93 | num_layers=3, mlp_dims=(400, 400, 400), dropouts=(0.5, 0.5), has_residual=True): 94 | 
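# The *FrnP ("parallel") variants in this repository refine the shared embedding twice with two
# independent feature-refinement modules: one refined copy feeds the interaction branch (here the
# stacked self-attention) and the other feeds the DNN branch, instead of both branches sharing a
# single refined representation as in the *Frn variants.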
super().__init__() 95 | self.num_fields = len(field_dims) 96 | self.linear = FeaturesLinear(field_dims) 97 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 98 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 99 | self.embed_output_dim = len(field_dims) * embed_dim 100 | self.atten_output_dim = len(field_dims) * atten_embed_dim 101 | self.has_residual = has_residual 102 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropouts[1]) 103 | 104 | self.self_attns = torch.nn.ModuleList([ 105 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 106 | ]) 107 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 108 | if self.has_residual: 109 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 110 | self.frn1 = FRN1 111 | self.frn2 = FRN2 112 | 113 | def forward(self, x): 114 | x_emb = self.embedding(x) 115 | x_emb1, _ = self.frn1(x_emb) 116 | x_emb2, _ = self.frn2(x_emb) 117 | x_mlp = self.mlp(x_emb2.view(-1, self.embed_output_dim)) 118 | 119 | atten_x = self.atten_embedding(x_emb1) 120 | cross_term = atten_x.transpose(0, 1) 121 | 122 | for self_attn in self.self_attns: 123 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 124 | 125 | cross_term = cross_term.transpose(0, 1) 126 | if self.has_residual: 127 | V_res = self.V_res_embedding(x_emb) 128 | cross_term += V_res 129 | 130 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 131 | 132 | pred_y = self.linear(x) + self.attn_fc(cross_term) + x_mlp 133 | return pred_y 134 | 135 | 136 | class AutoInt(torch.nn.Module): 137 | 138 | def __init__(self, field_dims, embed_dim, atten_embed_dim=32, num_heads=4, 139 | num_layers=3, dropouts=(0.5, 0.5), has_residual=True): 140 | super().__init__() 141 | self.num_fields = len(field_dims) 142 | self.linear = FeaturesLinear(field_dims) 143 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 144 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 145 | self.atten_output_dim = len(field_dims) * atten_embed_dim 146 | self.has_residual = has_residual 147 | 148 | self.self_attns = torch.nn.ModuleList([ 149 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 150 | ]) 151 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 152 | if self.has_residual: 153 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 154 | 155 | def forward(self, x): 156 | x_emb = self.embedding(x) 157 | atten_x = self.atten_embedding(x_emb) 158 | 159 | cross_term = atten_x.transpose(0, 1) 160 | 161 | for self_attn in self.self_attns: 162 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 163 | 164 | cross_term = cross_term.transpose(0, 1) 165 | if self.has_residual: 166 | V_res = self.V_res_embedding(x_emb) 167 | cross_term += V_res 168 | 169 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 170 | 171 | pred_y = self.attn_fc(cross_term) + self.linear(x) 172 | return pred_y 173 | 174 | class AutoIntFrn(torch.nn.Module): 175 | 176 | def __init__(self, field_dims, embed_dim, FRN=None, atten_embed_dim=32, num_heads=4, 177 | num_layers=3, dropouts=(0.5, 0.5), has_residual=True): 178 | super().__init__() 179 | self.num_fields = len(field_dims) 180 | self.linear = FeaturesLinear(field_dims) 181 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 182 | self.atten_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 183 | 
self.atten_output_dim = len(field_dims) * atten_embed_dim 184 | self.has_residual = has_residual 185 | 186 | self.self_attns = torch.nn.ModuleList([ 187 | torch.nn.MultiheadAttention(atten_embed_dim, num_heads, dropout=dropouts[0]) for _ in range(num_layers) 188 | ]) 189 | self.attn_fc = torch.nn.Linear(self.atten_output_dim, 1) 190 | if self.has_residual: 191 | self.V_res_embedding = torch.nn.Linear(embed_dim, atten_embed_dim) 192 | 193 | self.frn = FRN 194 | 195 | def forward(self, x): 196 | x_emb = self.embedding(x) 197 | x_emb, _ = self.frn(x_emb) 198 | atten_x = self.atten_embedding(x_emb) 199 | 200 | cross_term = atten_x.transpose(0, 1) 201 | 202 | for self_attn in self.self_attns: 203 | cross_term, _ = self_attn(cross_term, cross_term, cross_term) 204 | 205 | cross_term = cross_term.transpose(0, 1) 206 | if self.has_residual: 207 | V_res = self.V_res_embedding(x_emb) 208 | cross_term += V_res 209 | 210 | cross_term = F.relu(cross_term).contiguous().view(-1, self.atten_output_dim) 211 | 212 | pred_y = self.attn_fc(cross_term) + self.linear(x) 213 | return pred_y 214 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/dcap.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from FRCTR.common import FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron 11 | 12 | class DeepCrossAttentionalProductNetwork(torch.nn.Module): 13 | 14 | def __init__(self, field_dims, embed_dim, num_heads, num_layers, mlp_dims, dropouts): 15 | super().__init__() 16 | num_fields = len(field_dims) 17 | self.cap = CrossAttentionalProductNetwork(num_fields, embed_dim, num_heads, num_layers, dropouts[0]) 18 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 19 | self.num_layers = num_layers 20 | self.embed_output_dim = num_fields * embed_dim 21 | self.attn_output_dim = num_layers * num_fields * (num_fields - 1) // 2 22 | self.mlp = MultiLayerPerceptron(self.attn_output_dim + self.embed_output_dim, mlp_dims, dropouts[1]) 23 | 24 | def generate_square_subsequent_mask(self, num_fields): 25 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 26 | Unmasked positions are filled with float(0.0). 27 | """ 28 | mask = (torch.triu(torch.ones(num_fields, num_fields)) == 1) 29 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 30 | return mask 31 | 32 | def forward(self, x): 33 | device = x.device 34 | attn_mask = self.generate_square_subsequent_mask(x.size(1)).to(device) 35 | embed_x = self.embedding(x) 36 | cross_term = self.cap(embed_x, attn_mask) 37 | y = torch.cat([embed_x.view(-1, self.embed_output_dim), cross_term], dim=1) 38 | x = self.mlp(y) 39 | return x.squeeze(1) 40 | 41 | class DCAPFrn(torch.nn.Module): 42 | """ 43 | A pytorch implementation of inner/outer Product Neural Network. 44 | Reference: 45 | Y Qu, et al. Product-based Neural Networks for User Response Prediction, 2016. 
46 | """ 47 | 48 | def __init__(self, field_dims, embed_dim, num_heads=1, num_layers=3, mlp_dims=(400,400,400), dropouts=(0.5,0.5), FRN=None): 49 | super().__init__() 50 | num_fields = len(field_dims) 51 | self.cap = CrossAttentionalProductNetwork(num_fields, embed_dim, num_heads, num_layers, dropout=dropouts[0]) 52 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 53 | self.frn = FRN 54 | self.num_layers = num_layers 55 | self.embed_output_dim = num_fields * embed_dim 56 | self.attn_output_dim = num_layers * num_fields * (num_fields - 1) // 2 57 | self.mlp = MultiLayerPerceptron(self.attn_output_dim + self.embed_output_dim, mlp_dims, dropouts[1]) 58 | 59 | def generate_square_subsequent_mask(self, num_fields): 60 | mask = (torch.triu(torch.ones(num_fields, num_fields)) == 1) 61 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 62 | return mask 63 | 64 | 65 | def forward(self, x): 66 | device = x.device 67 | attn_mask = self.generate_square_subsequent_mask(x.size(1)).to(device) 68 | x_emb = self.embedding(x) 69 | x_emb, weight = self.frn(x_emb) 70 | cross_term = self.cap(x_emb, attn_mask) 71 | x_cat = torch.cat([x_emb.view(-1, self.embed_output_dim), cross_term], dim=1) 72 | pred_y = self.mlp(x_cat) 73 | return pred_y 74 | 75 | 76 | class CrossAttentionalProductNetwork(torch.nn.Module): 77 | 78 | def __init__(self, num_fields, embed_dim, num_heads, num_layers, dropout, kernel_type='mat'): 79 | super().__init__() 80 | self.layers = torch.nn.ModuleList([]) 81 | self.layers.extend( 82 | [self.build_encoder_layer(num_fields=num_fields, embed_dim=embed_dim, num_heads=num_heads, 83 | dropout=dropout, kernel_type=kernel_type) for _ in range(num_layers)] 84 | ) 85 | 86 | 87 | def build_encoder_layer(self, num_fields, embed_dim, num_heads, dropout, kernel_type='mat'): 88 | return CrossProductNetwork(num_fields=num_fields, embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, kernel_type=kernel_type) 89 | 90 | def forward(self, x, attn_mask=None): 91 | x0 = x 92 | output = [] 93 | for layer in self.layers: 94 | x, y = layer(x, x0, attn_mask) 95 | output.append(y) 96 | output = torch.cat(output, dim=1) 97 | 98 | return output 99 | 100 | 101 | class CrossProductNetwork(torch.nn.Module): 102 | 103 | def __init__(self, num_fields, embed_dim, num_heads, dropout=0.2, kernel_type='mat'): 104 | super().__init__() 105 | num_ix = num_fields * (num_fields - 1) // 2 106 | if kernel_type == 'mat': 107 | kernel_shape = embed_dim, num_ix, embed_dim 108 | elif kernel_type == 'vec': 109 | kernel_shape = num_ix, embed_dim 110 | elif kernel_type == 'num': 111 | kernel_shape = num_ix, 1 112 | else: 113 | raise ValueError('unknown kernel type: ' + kernel_type) 114 | self.kernel_type = kernel_type 115 | self.kernel = torch.nn.Parameter(torch.zeros(kernel_shape)) 116 | self.avg_pool = torch.nn.AdaptiveAvgPool1d(num_fields) 117 | # self.fc = torch.nn.Linear(embed_dim, 1) 118 | self.attn = MultiheadAttentionInnerProduct(num_fields, embed_dim, num_heads, dropout) 119 | torch.nn.init.xavier_uniform_(self.kernel.data) 120 | 121 | def forward(self, x, x0, attn_mask=None): 122 | 123 | bsz, num_fields, embed_dim = x0.size() 124 | row, col = list(), list() 125 | for i in range(num_fields - 1): 126 | for j in range(i + 1, num_fields): 127 | row.append(i), col.append(j) 128 | 129 | x, _ = self.attn(x, x, x, attn_mask) 130 | p, q = x[:, row], x0[:, col] 131 | if self.kernel_type == 'mat': 132 | kp = torch.sum(p.unsqueeze(1) * self.kernel, dim=-1).permute(0, 2, 1) # (bsz, 
n(n-1)/2, embed_dim) 133 | kpq = kp * q 134 | 135 | x = self.avg_pool(kpq.permute(0, 2, 1)).permute(0, 2, 1) # (bsz, n, embed_dim) 136 | 137 | return x, torch.sum(kpq, dim=-1) 138 | else: 139 | return torch.sum(p * q * self.kernel.unsqueeze(0), -1) 140 | 141 | 142 | class MultiheadAttentionInnerProduct(torch.nn.Module): 143 | def __init__(self, num_fields, embed_dim, num_heads, dropout): 144 | super().__init__() 145 | self.num_fields = num_fields 146 | self.mask = (torch.triu(torch.ones(num_fields, num_fields), diagonal=1) == 1) 147 | self.num_cross_terms = num_fields * (num_fields - 1) // 2 148 | self.embed_dim = embed_dim 149 | self.num_heads = num_heads 150 | self.dropout_p = dropout 151 | head_dim = embed_dim // num_heads 152 | assert head_dim * num_heads == embed_dim, "head dim is not divisible by embed dim" 153 | self.head_dim = head_dim 154 | self.scale = self.head_dim ** -0.5 155 | 156 | self.linear_q = torch.nn.Linear(embed_dim, num_heads * head_dim, bias=True) 157 | self.linear_k = torch.nn.Linear(embed_dim, num_heads * head_dim, bias=True) 158 | self.avg_pool = torch.nn.AdaptiveAvgPool1d(num_fields) 159 | self.output_layer = torch.nn.Linear(embed_dim, embed_dim, bias=True) 160 | 161 | # self.fc = torch.nn.Linear(embed_dim, 1) 162 | 163 | def forward(self, query, key, value, attn_mask=None, need_weights=False): 164 | bsz, num_fields, embed_dim = query.size() 165 | 166 | q = self.linear_q(query) 167 | q = q.transpose(0, 1).contiguous() 168 | q = q.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 169 | 1) 170 | q = q * self.scale 171 | k = self.linear_k(key) 172 | k = k.transpose(0, 1).contiguous() 173 | k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 174 | v = value.transpose(0, 1).contiguous() 175 | v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 176 | 177 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 178 | 179 | if attn_mask is not None: 180 | attn_output_weights += attn_mask 181 | 182 | attn_output_weights = torch.softmax( 183 | attn_output_weights, dim=-1) 184 | attn_output_weights = F.dropout(attn_output_weights, self.dropout_p) 185 | 186 | attn_output = torch.bmm(attn_output_weights, v) 187 | assert list(attn_output.size()) == [bsz * self.num_heads, num_fields, self.head_dim] 188 | attn_output = attn_output.transpose(0, 1).contiguous().view(num_fields, bsz, embed_dim).transpose(0, 1) 189 | attn_output = self.output_layer(attn_output) 190 | if need_weights: 191 | attn_output_weights = attn_output_weights.view(bsz, self.num_heads, num_fields, num_fields) 192 | return attn_output, attn_output_weights.sum(dim=0) / bsz 193 | 194 | return attn_output, None 195 | 196 | 197 | def get_activation_fn(activation: str): 198 | """ Returns the activation function corresponding to `activation` """ 199 | if activation == "relu": 200 | return torch.relu 201 | elif activation == "tanh": 202 | return torch.tanh 203 | elif activation == "linear": 204 | return lambda x: x 205 | else: 206 | raise RuntimeError("--activation-fn {} not supported".format(activation)) -------------------------------------------------------------------------------- /FRCTR/model_zoo/dcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from FRCTR.common import FeaturesEmbedding, MultiLayerPerceptron, CrossNetwork, BasicFRCTR 9 | 10 | class CNFrn(BasicFRCTR): 11 | def __init__(self, field_dims, embed_dim, 
FRN=None, cn_layers=2): 12 | super(CNFrn, self).__init__(field_dims, embed_dim, FRN) 13 | self.embed_output_dim = len(field_dims) * embed_dim 14 | self.cross_net = CrossNetwork(self.embed_output_dim, cn_layers) 15 | self.fc = nn.Linear(self.embed_output_dim, 1) 16 | 17 | def forward(self, x): 18 | x_emb = self.embedding(x) 19 | x_emb, _ = self.frn(x_emb) 20 | x_emb = x_emb.reshape(-1, self.embed_output_dim) 21 | cross_cn = self.cross_net(x_emb) 22 | pred_y = self.fc(cross_cn) 23 | return pred_y 24 | 25 | class DCNFrn(BasicFRCTR): 26 | def __init__(self, field_dims, embed_dim, FRN=None, cn_layers=3, mlp_layers=(400, 400, 400), dropout=0.5): 27 | super(DCNFrn, self).__init__(field_dims, embed_dim, FRN) 28 | 29 | self.embed_output_dim = len(field_dims) * embed_dim 30 | self.cross_net = CrossNetwork(self.embed_output_dim, cn_layers) 31 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=False, dropout=dropout) 32 | self.fc = nn.Linear(mlp_layers[-1] + self.embed_output_dim, 1) 33 | 34 | def forward(self, x): 35 | x_emb = self.embedding(x) 36 | x_emb, _ = self.frn(x_emb) 37 | x_emb = x_emb.reshape(-1, self.embed_output_dim) 38 | 39 | cross_cn = self.cross_net(x_emb) 40 | cross_mlp = self.mlp(x_emb) 41 | pred_y = self.fc(torch.cat([cross_cn, cross_mlp], dim=1)) 42 | return pred_y 43 | 44 | 45 | class DCNFrnP(nn.Module): 46 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, 47 | cn_layers=3, mlp_layers=(400, 400, 400), dropout=0.5): 48 | super(DCNFrnP, self).__init__() 49 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 50 | self.embed_output_dim = len(field_dims) * embed_dim 51 | 52 | if not FRN1 or not FRN2: 53 | raise ValueError("Feature Refinement Network is None") 54 | self.frn1 = FRN1 55 | self.frn2 = FRN2 56 | 57 | self.cross_net = CrossNetwork(self.embed_output_dim, cn_layers) 58 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=False, dropout=dropout) 59 | self.fc = nn.Linear(mlp_layers[-1] + self.embed_output_dim, 1) 60 | 61 | def forward(self, x): 62 | x_emb = self.embedding(x) 63 | x_emb1, _ = self.frn1(x_emb) 64 | x_emb2, _ = self.frn2(x_emb) 65 | x_emb1 = x_emb1.reshape(-1, self.embed_output_dim) 66 | x_emb2 = x_emb2.reshape(-1, self.embed_output_dim) 67 | 68 | cross_cn = self.cross_net(x_emb1) 69 | cross_mlp = self.mlp(x_emb2) 70 | pred_y = self.fc(torch.cat([cross_cn, cross_mlp], dim=1)) 71 | return pred_y 72 | 73 | 74 | 75 | class DCN(nn.Module): 76 | def __init__(self, field_dims, embed_dim, cn_layers=3, mlp_layers=(400, 400, 400), dropout=0.5): 77 | super(DCN, self).__init__() 78 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 79 | self.embed_output_dim = len(field_dims) * embed_dim 80 | 81 | self.cross_net = CrossNetwork(self.embed_output_dim, cn_layers) 82 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=False, dropout=dropout) 83 | self.fc = nn.Linear(mlp_layers[-1] + self.embed_output_dim, 1) 84 | 85 | def forward(self, x): 86 | x_emb = self.embedding(x) 87 | 88 | x_emb = x_emb.reshape(-1, self.embed_output_dim) 89 | 90 | cross_cn = self.cross_net(x_emb) 91 | cross_mlp = self.mlp(x_emb) 92 | pred_y = self.fc(torch.cat([cross_cn, cross_mlp], dim=1)) 93 | return pred_y 94 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/dcnv2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 |
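# The cross network used in this file follows DCN-V2: each layer computes
#     x_{l+1} = x_0 * (W_l @ x_l + b_l) + x_l
# with a full (input_dim x input_dim) weight matrix, whereas the original DCN (dcn.py) uses the
# rank-one form x_{l+1} = x_0 * (x_l^T w_l) + b_l + x_l. A minimal sketch of one such layer,
# assuming this is what FRCTR.common.CrossNetworkV2 stacks internally (the real implementation
# lives in FRCTR/common and is not shown here):
#
#     class CrossLayerV2(nn.Module):
#         def __init__(self, input_dim):
#             super().__init__()
#             self.linear = nn.Linear(input_dim, input_dim, bias=True)
#
#         def forward(self, x0, xl):
#             # element-wise gate of the transformed state by the original input, plus residual
#             return x0 * self.linear(xl) + xl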
6 | import torch 7 | import torch.nn as nn 8 | 9 | from FRCTR.common import FeaturesEmbedding, MultiLayerPerceptron, CrossNetworkV2, BasicFRCTR 10 | 11 | class CrossNetV2(torch.nn.Module): 12 | def __init__(self, field_dims, embed_dim, cn_layers=3): 13 | super(CrossNetV2, self).__init__() 14 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 15 | self.embed_output_dim = len(field_dims) * embed_dim 16 | self.cross_net = CrossNetworkV2(self.embed_output_dim, cn_layers) 17 | self.pred_layer = torch.nn.Linear(self.embed_output_dim, 1) 18 | 19 | def forward(self, x): 20 | x_embed = self.embedding(x).view(-1, self.embed_output_dim) 21 | cross_cn = self.cross_net(x_embed) 22 | pred_y = self.pred_layer(cross_cn) 23 | return pred_y 24 | 25 | 26 | class CN2Frn(BasicFRCTR): 27 | def __init__(self, field_dims, embed_dim, FRN=None, cn_layers=3): 28 | super(CN2Frn, self).__init__(field_dims, embed_dim, FRN) 29 | self.embed_output_dim = len(field_dims) * embed_dim 30 | self.cross_net = CrossNetworkV2(self.embed_output_dim, cn_layers) 31 | self.pred_layer = torch.nn.Linear(self.embed_output_dim, 1) 32 | 33 | def forward(self, x): 34 | x_emb = self.embedding(x) 35 | x_emb, _ = self.frn(x_emb) 36 | x_emb = x_emb.reshape(-1, self.embed_output_dim) 37 | cross_cn = self.cross_net(x_emb) 38 | pred_y = self.pred_layer(cross_cn) 39 | return pred_y 40 | 41 | class DCNV2(nn.Module): 42 | def __init__(self, field_dims, embed_dim, cn_layers=3, mlp_layers=(400, 400, 400), dropout=0.5): 43 | super(DCNV2, self).__init__() 44 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 45 | self.embed_output_dim = len(field_dims) * embed_dim 46 | self.cross_net = CrossNetworkV2(self.embed_output_dim, cn_layers) 47 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=False, dropout=dropout) 48 | self.fc = torch.nn.Linear(mlp_layers[-1] + self.embed_output_dim, 1) 49 | 50 | def forward(self, x): 51 | x_emb = self.embedding(x).view(-1, self.embed_output_dim) # B,F*E 52 | cross_cn = self.cross_net(x_emb) 53 | cross_mlp = self.mlp(x_emb) 54 | 55 | pred_y = self.fc(torch.cat([cross_cn, cross_mlp], dim=1)) 56 | return pred_y 57 | 58 | 59 | class DCNV2Frn(BasicFRCTR): 60 | def __init__(self, field_dims, embed_dim, FRN=None, cn_layers=3, 61 | mlp_layers=(400, 400, 400), dropout=0.5): 62 | super(DCNV2Frn, self).__init__(field_dims, embed_dim, FRN) 63 | self.embed_output_dim = len(field_dims) * embed_dim 64 | self.cross_net = CrossNetworkV2(self.embed_output_dim, cn_layers) 65 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=False, dropout=dropout) 66 | self.fc = torch.nn.Linear(mlp_layers[-1] + self.embed_output_dim, 1) 67 | 68 | def forward(self, x): 69 | x_emb = self.embedding(x) 70 | x_emb, _ = self.frn(x_emb) 71 | x_emb = x_emb.reshape(-1, self.embed_output_dim) 72 | cross_cn = self.cross_net(x_emb) 73 | cross_mlp = self.mlp(x_emb) 74 | 75 | pred_y = self.fc(torch.cat([cross_cn, cross_mlp], dim=1)) 76 | return pred_y 77 | 78 | 79 | class DCNV2FrnP(nn.Module): 80 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, 81 | cn_layers=4, mlp_layers=(400, 400, 400), 82 | dropout=0.5): 83 | super(DCNV2FrnP, self).__init__() 84 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 85 | if not FRN1 or not FRN2: 86 | raise ValueError("Feature Refinement Network is None") 87 | self.frn1 = FRN1 88 | self.frn2 = FRN2 89 | self.embed_output_dim = len(field_dims) * embed_dim 90 | self.cross_net = CrossNetworkV2(self.embed_output_dim, cn_layers) 91 | 
self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=False, dropout=dropout) 92 | self.fc = torch.nn.Linear(mlp_layers[-1] + self.embed_output_dim, 1) 93 | 94 | def forward(self, x): 95 | x_emb = self.embedding(x) 96 | x_emb1, _ = self.frn1(x_emb) 97 | x_emb2, _ = self.frn2(x_emb) 98 | x_emb1 = x_emb1.reshape(-1, self.embed_output_dim) 99 | x_emb2 = x_emb2.reshape(-1, self.embed_output_dim) 100 | cross_cn = self.cross_net(x_emb1) 101 | cross_mlp = self.mlp(x_emb2) 102 | 103 | pred_y = self.fc(torch.cat([cross_cn, cross_mlp], dim=1)) 104 | return pred_y 105 | 106 | 107 | class DCNV2S(torch.nn.Module): 108 | def __init__(self, field_dims, embed_dim, cn_layers=3, mlp_layers=(400, 400, 400), dropout=0.5): 109 | super(DCNV2S, self).__init__() 110 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 111 | 112 | self.embed_output_dim = len(field_dims) * embed_dim 113 | self.cross_net = CrossNetworkV2(self.embed_output_dim, cn_layers) 114 | self.pred_layer = MultiLayerPerceptron(self.embed_output_dim, mlp_layers, output_layer=True, 115 | dropout=dropout) 116 | 117 | def forward(self, x): 118 | x_embed = self.embedding(x).view(-1, self.embed_output_dim) 119 | cross_cn = self.cross_net(x_embed) 120 | pred_y = self.pred_layer(cross_cn) 121 | return pred_y 122 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/deepfm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | from torch import nn 7 | from FRCTR.common import BasicFRCTR, FeaturesLinear, FeaturesEmbedding, \ 8 | FactorizationMachine, MultiLayerPerceptron 9 | 10 | 11 | class DeepFM(nn.Module): 12 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), dropout=0.5): 13 | super(DeepFM, self).__init__() 14 | self.lr = FeaturesLinear(field_dims) 15 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 16 | 17 | self.embed_output_dim = len(field_dims) * embed_dim 18 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, embed_dims=mlp_layers, 19 | dropout=dropout, output_layer=True) 20 | 21 | self.fm = FactorizationMachine(reduce_sum=True) 22 | 23 | def forward(self, x): 24 | x_emb = self.embedding(x) 25 | pred_y = self.lr(x) + self.fm(x_emb) + self.mlp(x_emb.view(x.size(0), -1)) 26 | return pred_y 27 | 28 | 29 | class DeepFMFrn(BasicFRCTR): 30 | def __init__(self, field_dims, embed_dim, FRN=None, mlp_layers=(400, 400, 400), dropout=0.5): 31 | super(DeepFMFrn, self).__init__(field_dims, embed_dim, FRN) 32 | self.lr = FeaturesLinear(field_dims) 33 | 34 | self.embed_output_dim = len(field_dims) * embed_dim 35 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, embed_dims=mlp_layers, 36 | dropout=dropout, output_layer=True) 37 | 38 | self.fm = FactorizationMachine(reduce_sum=True) 39 | 40 | def forward(self, x): 41 | x_emb = self.embedding(x) 42 | x_emb, x_weight = self.frn(x_emb) 43 | pred_y = self.lr(x) + self.fm(x_emb) + self.mlp(x_emb.reshape(x.size(0), -1)) 44 | return pred_y 45 | 46 | class DeepFMFrnP(nn.Module): 47 | """ 48 | DeepFM with two separate feature refinement modules. 
49 | """ 50 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, mlp_layers=(400, 400, 400), dropout=0.5): 51 | super(DeepFMFrnP, self).__init__() 52 | self.lr = FeaturesLinear(field_dims) 53 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 54 | 55 | if not FRN1 or not FRN2: 56 | raise ValueError("Feature Refinement Network is None") 57 | self.frn1 = FRN1 58 | self.frn2 = FRN2 59 | 60 | self.embed_output_dim = len(field_dims) * embed_dim 61 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, embed_dims=mlp_layers, 62 | dropout=dropout, output_layer=True) 63 | 64 | self.fm = FactorizationMachine(reduce_sum=True) 65 | 66 | def forward(self, x): 67 | x_emb = self.embedding(x) 68 | x_emb1, x_weight1 = self.frn1(x_emb) 69 | x_emb2, x_weight2 = self.frn2(x_emb) 70 | pred_y = self.lr(x) + self.fm(x_emb1) + self.mlp(x_emb2.reshape(x.size(0), -1)) 71 | return pred_y -------------------------------------------------------------------------------- /FRCTR/model_zoo/fed.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from FRCTR.common import FeaturesEmbedding, MultiLayerPerceptron, FactorizationMachine 10 | 11 | class FED(nn.Module): 12 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), drm_flag=True, dropout=0.5): 13 | super(FED, self).__init__() 14 | 15 | self.embedding = FeaturesEmbedding(field_dims=field_dims, embed_dim=embed_dim) 16 | self.drm_flag = drm_flag 17 | self.drm = DRM(len(field_dims)) 18 | 19 | self.mlp_input_dim = len(field_dims) * embed_dim 20 | self.element_wise = MultiLayerPerceptron(self.mlp_input_dim, mlp_layers, output_layer=False) 21 | 22 | self.field_wise = FieldAttentionModule(embed_dim) 23 | 24 | self.out_len = self.mlp_input_dim * 2 + list(mlp_layers)[-1] 25 | 26 | self.fm = FactorizationMachine(reduce_sum=True) 27 | 28 | self.lin_out = nn.Linear(self.out_len, 1) 29 | 30 | def forward(self, x): 31 | b = x.size(0) 32 | E = self.embedding(x) 33 | E = self.drm(E) 34 | 35 | E_f = self.field_wise(E) + E 36 | E_e = self.element_wise(E.reshape(b, -1)) 37 | E_con = torch.cat([E_f.reshape(b, -1), 38 | E_e.reshape(b, -1), 39 | E.reshape(b, -1)], dim=1) 40 | pred_y = self.lin_out(E_con) 41 | return pred_y 42 | 43 | class DRM(nn.Module): 44 | def __init__(self, num_field): 45 | super(DRM, self).__init__() 46 | self.fan = FieldAttentionModule(num_field) 47 | 48 | def forward(self, V): 49 | U = V.permute(0, 2, 1) 50 | E = self.fan(U).permute(0, 2, 1) 51 | E = E + V 52 | return E 53 | 54 | class FieldAttentionModule(nn.Module): 55 | def __init__(self, embed_dim): 56 | super(FieldAttentionModule, self).__init__() 57 | self.trans_Q = nn.Linear(embed_dim, embed_dim) 58 | self.trans_K = nn.Linear(embed_dim, embed_dim) 59 | self.trans_V = nn.Linear(embed_dim, embed_dim) 60 | 61 | def forward(self, x, scale=None): 62 | """ 63 | :param x: B,F,E 64 | :return: B,F,E 65 | """ 66 | Q = self.trans_Q(x) 67 | K = self.trans_K(x) 68 | V = self.trans_V(x) 69 | 70 | attention = torch.matmul(Q, K.permute(0, 2, 1)) 71 | if scale: 72 | attention = attention * scale 73 | attention = F.softmax(attention, dim=-1) 74 | context = torch.matmul(attention, V) 75 | 76 | return context 77 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/fibinet.py: -------------------------------------------------------------------------------- 1 | 
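# FiBiNet (below) combines two ideas: a SENET-style module that re-weights each field embedding
# by a learned importance score (squeeze each field to a summary statistic, excite it through a
# small bottleneck MLP, then rescale the field embeddings), and bilinear interaction layers that
# replace the plain inner product with p_ij = (v_i W) * v_j for every field pair (i, j). The
# SenetLayer and BilinearInteractionLayer used here are imported from FRCTR.common and are not
# shown in this file, so the details above are assumed from the FiBiNET paper rather than read
# from this code.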
# -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from FRCTR.common import FeaturesEmbedding, MultiLayerPerceptron, \ 10 | FeaturesLinear, SenetLayer, BilinearInteractionLayer 11 | 12 | 13 | class FiBiNet(nn.Module): 14 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), dropout=0.5, bilinear_type="all"): 15 | super(FiBiNet, self).__init__() 16 | num_fields = len(field_dims) 17 | self.linear = FeaturesLinear(field_dims) 18 | 19 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 20 | 21 | self.senet = SenetLayer(num_fields) 22 | 23 | self.bilinear = BilinearInteractionLayer(num_fields, embed_dim, bilinear_type=bilinear_type) 24 | self.bilinear2 = BilinearInteractionLayer(num_fields, embed_dim, bilinear_type=bilinear_type) 25 | 26 | num_inter = num_fields * (num_fields - 1) // 2 27 | self.embed_output_size = num_inter * embed_dim 28 | self.mlp = MultiLayerPerceptron(2 * self.embed_output_size, mlp_layers, dropout=dropout) 29 | 30 | def forward(self, x): 31 | lin = self.linear(x) 32 | x_emb = self.embedding(x) 33 | x_senet, x_weight = self.senet(x_emb) 34 | 35 | x_bi1 = self.bilinear(x_emb) 36 | x_bi2 = self.bilinear2(x_senet) 37 | 38 | x_con = torch.cat([x_bi1.view(x.size(0), -1), 39 | x_bi2.view(x.size(0), -1)], dim=1) 40 | 41 | pred_y = self.mlp(x_con) + lin 42 | return pred_y 43 | 44 | 45 | class FiBiNetFrn(nn.Module): 46 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, 47 | mlp_layers=(400, 400, 400), dropout=0.5, bilinear_type="all"): 48 | super(FiBiNetFrn, self).__init__() 49 | num_fields = len(field_dims) 50 | self.linear = FeaturesLinear(field_dims) 51 | 52 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 53 | 54 | if not FRN1 or not FRN2: 55 | raise ValueError("Feature Refinement Network is None") 56 | self.frn1 = FRN1 57 | self.frn2 = FRN2 58 | 59 | self.bilinear = BilinearInteractionLayer(num_fields, embed_dim, bilinear_type=bilinear_type) 60 | self.bilinear2 = BilinearInteractionLayer(num_fields, embed_dim, bilinear_type=bilinear_type) 61 | 62 | num_inter = num_fields * (num_fields - 1) // 2 63 | self.embed_output_size = num_inter * embed_dim 64 | self.mlp = MultiLayerPerceptron(2 * self.embed_output_size, mlp_layers, dropout=dropout) 65 | 66 | def forward(self, x): 67 | lin = self.linear(x) 68 | x_emb = self.embedding(x) 69 | x_emb1, x_weight1 = self.frn1(x_emb) 70 | x_emb2, x_weight2 = self.frn2(x_emb) 71 | 72 | x_bi1 = self.bilinear(x_emb1) 73 | x_bi2 = self.bilinear2(x_emb2) 74 | 75 | x_con = torch.cat([x_bi1.view(x.size(0), -1), 76 | x_bi2.view(x.size(0), -1)], dim=1) 77 | 78 | pred_y = self.mlp(x_con) + lin 79 | return pred_y 80 | 81 | class FiBiNetFrn1(nn.Module): 82 | def __init__(self, field_dims, embed_dim, FRN=None, 83 | mlp_layers=(400, 400, 400), dropout=0.5, bilinear_type="all"): 84 | super(FiBiNetFrn1, self).__init__() 85 | num_fields = len(field_dims) 86 | self.linear = FeaturesLinear(field_dims) 87 | 88 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 89 | 90 | if not FRN: 91 | raise ValueError("Feature Refinement Network is None") 92 | self.frn = FRN 93 | 94 | self.bilinear = BilinearInteractionLayer(num_fields, embed_dim, bilinear_type=bilinear_type) 95 | self.bilinear2 = BilinearInteractionLayer(num_fields, embed_dim, bilinear_type=bilinear_type) 96 | 97 | num_inter = num_fields * (num_fields - 1) // 2 98 | self.embed_output_size = num_inter * embed_dim 99 | self.mlp = MultiLayerPerceptron(2 * self.embed_output_size, 
mlp_layers, dropout=dropout) 100 | 101 | def forward(self, x): 102 | lin = self.linear(x) 103 | x_emb = self.embedding(x) 104 | x_emb1, x_weight1 = self.frn(x_emb) 105 | 106 | x_bi1 = self.bilinear(x_emb) 107 | x_bi2 = self.bilinear2(x_emb1) 108 | 109 | x_con = torch.cat([x_bi1.view(x.size(0), -1), 110 | x_bi2.view(x.size(0), -1)], dim=1) 111 | 112 | pred_y = self.mlp(x_con) + lin 113 | return pred_y 114 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/fint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from FRCTR.common import FeaturesLinear, FeaturesEmbedding, FactorizationMachine, \ 11 | MultiLayerPerceptron, BasicFRCTR 12 | 13 | class FintLayer(nn.Module): 14 | def forward(self, x_vl, x_wl, x_ul, x_embed): 15 | x_vl = x_vl * (torch.matmul(x_wl, x_embed)) + x_ul * x_vl 16 | return x_vl 17 | 18 | 19 | class FINT(nn.Module): 20 | """ 21 | 1、Embedding layer 22 | 2、Field aware interaction layer。 23 | 3、DNN layer for prediction 24 | """ 25 | 26 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), num_deep=3, dropout=0.5): 27 | super(FINT, self).__init__() 28 | 29 | self.linear = FeaturesLinear(field_dims) 30 | 31 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 32 | num_fields = len(field_dims) 33 | self.num_deep = num_deep 34 | self.fints_param = nn.ParameterList( 35 | [nn.Parameter(torch.randn(num_fields, num_fields)) for _ in range(num_deep)]) 36 | self.Ul = nn.ParameterList([nn.Parameter(torch.ones(1, num_fields, 1)) for _ in range(num_deep)]) 37 | 38 | self.embed_output_size = num_fields * embed_dim 39 | self.mlp = MultiLayerPerceptron(self.embed_output_size, embed_dims=mlp_layers, dropout=0.5) 40 | 41 | def forward(self, x): 42 | x_emb = self.embedding(x) 43 | x_fint = x_emb 44 | for i in range(self.num_deep): 45 | x_fint = x_fint * torch.matmul(self.fints_param[i], x_emb) + self.Ul[i] * x_fint 46 | 47 | pred_y = self.mlp(x_fint.view(x_fint.size(0), -1)) 48 | return pred_y 49 | 50 | class FINTFrn(BasicFRCTR): 51 | """ 52 | 1、Embedding layer 53 | 2、Field aware interaction layer。 54 | 3、DNN layer for prediction 55 | """ 56 | 57 | def __init__(self, field_dims, embed_dim, FRN=None, mlp_layers=(400, 400, 400), num_deep=3, dropout=0.5): 58 | super(FINTFrn, self).__init__(field_dims, embed_dim, FRN) 59 | 60 | self.linear = FeaturesLinear(field_dims) 61 | num_fields = len(field_dims) 62 | self.num_deep = num_deep 63 | self.fints_param = nn.ParameterList( 64 | [nn.Parameter(torch.randn(num_fields, num_fields)) for _ in range(num_deep)]) 65 | self.Ul = nn.ParameterList([nn.Parameter(torch.ones(1, num_fields, 1)) for _ in range(num_deep)]) 66 | 67 | self.embed_output_size = num_fields * embed_dim 68 | self.mlp = MultiLayerPerceptron(self.embed_output_size, embed_dims=mlp_layers, dropout=0.5) 69 | 70 | def forward(self, x): 71 | x_emb = self.embedding(x) 72 | x_emb, weight = self.frn(x_emb) 73 | x_fint = x_emb 74 | for i in range(self.num_deep): 75 | x_fint = x_fint * torch.matmul(self.fints_param[i], x_emb) + self.Ul[i] * x_fint 76 | 77 | pred_y = self.mlp(x_fint.view(x_fint.size(0), -1)) 78 | return pred_y -------------------------------------------------------------------------------- /FRCTR/model_zoo/fm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | 
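# The FM prediction used below is y(x) = w_0 + sum_i w_i x_i + sum_{i<j} <v_i, v_j> x_i x_j, and
# the pairwise term can be computed in O(k * n) with the identity
#     sum_{i<j} <v_i, v_j> x_i x_j = 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ].
# FactorizationMachine(reduce_sum=True) from FRCTR.common presumably applies this trick to the
# (batch, num_fields, embed_dim) embedding tensor; a minimal sketch of the second-order term:
#
#     square_of_sum = x_emb.sum(dim=1) ** 2                 # (B, E)
#     sum_of_square = (x_emb ** 2).sum(dim=1)               # (B, E)
#     second_order = 0.5 * (square_of_sum - sum_of_square).sum(dim=1, keepdim=True)   # (B, 1)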
""" 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | from FRCTR.common import FeaturesLinear, FeaturesEmbedding, BasicFRCTR, FactorizationMachine 9 | 10 | 11 | class FM(nn.Module): 12 | def __init__(self, field_dims, emb_dim): 13 | super(FM, self).__init__() 14 | self.lr = FeaturesLinear(field_dims) 15 | self.embedding = FeaturesEmbedding(field_dims, emb_dim) 16 | self.fm = FactorizationMachine(reduce_sum=True) 17 | 18 | def forward(self, x): 19 | x_emb = self.embedding(x) 20 | pred_y = self.lr(x) + self.fm(x_emb) 21 | return pred_y 22 | 23 | 24 | class FMFrn(BasicFRCTR): 25 | def __init__(self, field_dims, emb_dim, FRN=None): 26 | super().__init__(field_dims, emb_dim, FRN) 27 | self.lr = FeaturesLinear(field_dims) 28 | self.fm = FactorizationMachine(reduce_sum=True) 29 | 30 | def forward(self, x): 31 | x_emb = self.embedding(x) 32 | x_emb, _ = self.frn(x_emb) 33 | pred_y = self.lr(x) + self.fm(x_emb) 34 | return pred_y 35 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/fmfm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from FRCTR.common import FeaturesLinear, FactorizationMachine, FeaturesEmbedding 10 | 11 | 12 | class FMFM(nn.Module): 13 | def __init__(self, field_dims, embed_dim, interaction_type="matrix"): 14 | super(FMFM, self).__init__() 15 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 16 | self.lr = FeaturesLinear(field_dims) 17 | self.num_field = len(field_dims) 18 | self.fm = FactorizationMachine(reduce_sum=True) 19 | self.inter_num = self.num_field * (self.num_field - 1) // 2 20 | self.field_interaction_type = interaction_type 21 | if self.field_interaction_type == "vector": # FvFM 22 | # F,1, E 23 | self.interaction_weight = nn.Parameter(torch.Tensor(self.inter_num, embed_dim)) 24 | elif self.field_interaction_type == "matrix": # FmFM 25 | # F,E,E 26 | self.interaction_weight = nn.Parameter(torch.Tensor(self.inter_num, embed_dim, embed_dim)) 27 | nn.init.xavier_uniform_(self.interaction_weight.data) 28 | # self.triu_index = torch.triu(torch.ones(self.num_field, self.num_field), 1).nonzero().cuda() 29 | self.row, self.col = list(), list() 30 | for i in range(self.num_field - 1): 31 | for j in range(i + 1, self.num_field): 32 | self.row.append(i), self.col.append(j) 33 | 34 | def forward(self, x): 35 | x_emb = self.embedding(x) # (B,F, E) 36 | # left_emb = torch.index_select(emb_x, 1, self.triu_index[:, 0]).cuda() #B,F,E 37 | # right_emb = torch.index_select(emb_x, 1, self.triu_index[:, 1]).cuda() # B,F,E 38 | left_emb = x_emb[:, self.row] 39 | right_emb = x_emb[:, self.col] 40 | # Transfer the embedding space of left_emb to corresponding space 41 | if self.field_interaction_type == "vector": 42 | left_emb = left_emb * self.interaction_weight # B,I,E 43 | elif self.field_interaction_type == "matrix": 44 | # B,F,1,E * F,E,E = B,F,1,E => B,F,E 45 | left_emb = torch.matmul(left_emb.unsqueeze(2), self.interaction_weight).squeeze(2) 46 | # FM interaction 47 | pred_y = (left_emb * right_emb).sum(dim=-1).sum(dim=-1, keepdim=True) 48 | pred_y += self.lr(x) 49 | return pred_y 50 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/fnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import 
torch.nn as nn 7 | 8 | from FRCTR.common import FeaturesLinear, FeaturesEmbedding, MultiLayerPerceptron, BasicFRCTR 9 | 10 | class FNN(nn.Module): 11 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), dropout=0.5): 12 | super(FNN, self).__init__() 13 | self.lr = FeaturesLinear(field_dims) 14 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 15 | 16 | self.embed_output_dim = len(field_dims) * embed_dim 17 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, embed_dims=mlp_layers, 18 | dropout=dropout, output_layer=True) 19 | 20 | def forward(self, x): 21 | x_emb = self.embedding(x) 22 | pred_y = self.lr(x) + self.mlp(x_emb.view(x.size(0), -1)) 23 | return pred_y 24 | 25 | class FNNFrn(BasicFRCTR): 26 | def __init__(self, field_dims, embed_dim, FRN=None, 27 | mlp_layers=(400, 400, 400), dropout=0.5): 28 | super(FNNFrn, self).__init__(field_dims, embed_dim, FRN) 29 | self.lr = FeaturesLinear(field_dims) 30 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 31 | 32 | self.embed_output_dim = len(field_dims) * embed_dim 33 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, embed_dims=mlp_layers, 34 | dropout=dropout, output_layer=True) 35 | 36 | def forward(self, x): 37 | x_emb = self.embedding(x) 38 | x_emb, weight = self.frn(x_emb) 39 | pred_y = self.lr(x) + self.mlp(x_emb.view(x.size(0), -1)) 40 | return pred_y -------------------------------------------------------------------------------- /FRCTR/model_zoo/fwfm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from FRCTR.common import FeaturesLinear, FeaturesEmbedding, BasicFRCTR 10 | 11 | 12 | class FwFM(nn.Module): 13 | def __init__(self, field_dims, embed_dim): 14 | super(FwFM, self).__init__() 15 | self.lr = FeaturesLinear(field_dims) 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | self.fwfm = FwFMInterLayer(len(field_dims)) 18 | 19 | def forward(self, x): 20 | x_emb = self.embedding(x) 21 | pred_y = self.lr(x) + self.fwfm(x_emb) 22 | return pred_y 23 | 24 | 25 | class FwFMInterLayer(nn.Module): 26 | def __init__(self, num_fields): 27 | super(FwFMInterLayer, self).__init__() 28 | 29 | self.num_fields = num_fields 30 | num_inter = (num_fields * (num_fields - 1)) // 2 31 | 32 | self.fc = nn.Linear(num_inter, 1) 33 | self.row, self.col = list(), list() 34 | for i in range(self.num_fields - 1): 35 | for j in range(i + 1, self.num_fields): 36 | self.row.append(i), self.col.append(j) 37 | 38 | def forward(self, x_embed): 39 | x_inter = torch.sum(x_embed[:, self.row] * x_embed[:, self.col], dim=2, keepdim=False) 40 | inter_sum = self.fc(x_inter) 41 | return inter_sum 42 | 43 | 44 | class FwFMFrn(BasicFRCTR): 45 | def __init__(self, field_dims, embed_dim, FRN=None): 46 | super(FwFMFrn, self).__init__(field_dims, embed_dim, FRN) 47 | self.lr = FeaturesLinear(field_dims) 48 | self.fwfm = FwFMInterLayer(len(field_dims)) 49 | 50 | def forward(self, x): 51 | x_emb = self.embedding(x) 52 | x_emb, weight = self.frn(x_emb) 53 | pred_y = self.lr(x) + self.fwfm(x_emb) 54 | return pred_y 55 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/lr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | from FRCTR.common import FeaturesLinear 9 | 10 | 11 | class
LogisticRegression(nn.Module): 12 | def __init__(self, field_dims): 13 | super(LogisticRegression, self).__init__() 14 | self.linear = FeaturesLinear(field_dims) 15 | 16 | def forward(self, x): 17 | pred_y = self.linear(x) 18 | return pred_y 19 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/nfm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from FRCTR.common import FeaturesLinear, FeaturesEmbedding, FactorizationMachine, MultiLayerPerceptron, BasicFRCTR 10 | 11 | 12 | class NeuralFactorizationMachineModel(nn.Module): 13 | def __init__(self, field_dims, embed_dim, mlp_dims, dropouts): 14 | super().__init__() 15 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 16 | self.linear = FeaturesLinear(field_dims) 17 | self.fm = torch.nn.Sequential( 18 | FactorizationMachine(reduce_sum=False), 19 | torch.nn.BatchNorm1d(embed_dim), 20 | torch.nn.Dropout(dropouts[0]) 21 | ) 22 | self.mlp = MultiLayerPerceptron(embed_dim, mlp_dims, dropouts[1]) 23 | 24 | def forward(self, x): 25 | x_emb = self.embedding(x) 26 | cross_term = self.fm(x_emb) 27 | pred_y = self.linear(x) + self.mlp(cross_term) 28 | return pred_y 29 | 30 | 31 | class NFMFrn(BasicFRCTR): 32 | def __init__(self, field_dims, embed_dim, FRN=None, mlp_layers=(400, 400, 400), dropouts=(0.5, 0.5)): 33 | super().__init__(field_dims, embed_dim, FRN) 34 | self.linear = FeaturesLinear(field_dims) 35 | self.fm = torch.nn.Sequential( 36 | FactorizationMachine(reduce_sum=False), 37 | torch.nn.BatchNorm1d(embed_dim), 38 | torch.nn.Dropout(dropouts[0]) 39 | ) 40 | self.mlp = MultiLayerPerceptron(embed_dim, mlp_layers, dropouts[1]) 41 | 42 | def forward(self, x): 43 | x_emb = self.embedding(x) 44 | x_emb, _ = self.frn(x_emb) 45 | cross_term = self.fm(x_emb) 46 | pred_y = self.linear(x) + self.mlp(cross_term) 47 | return pred_y 48 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/pnn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from FRCTR.common import FeaturesEmbedding, InnerProductNetwork, OuterProductNetwork, MultiLayerPerceptron, BasicFRCTR 10 | from common import OuterProductNetwork2 11 | 12 | 13 | class IPNN(nn.Module): 14 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), dropout=0.5): 15 | super(IPNN, self).__init__() 16 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 17 | num_fields = len(field_dims) 18 | self.pnn = InnerProductNetwork(num_fields) 19 | 20 | self.embed_output_dim = num_fields * embed_dim 21 | self.inter_size = num_fields * (num_fields - 1) // 2 22 | self.mlp = MultiLayerPerceptron(self.inter_size + self.input_dim, mlp_layers, dropout=dropout) 23 | 24 | def forward(self, x): 25 | # B,F,E 26 | x_emb = self.embedding(x) 27 | cross_ipnn = self.pnn(x_emb) 28 | 29 | x = torch.cat([x_emb.view(-1, self.embed_output_dim), cross_ipnn], dim=1) 30 | pred_y = self.mlp(x) 31 | return pred_y 32 | 33 | 34 | class IPNNFrn(BasicFRCTR): 35 | def __init__(self, field_dims, embed_dim, FRN=None, mlp_layers=(400, 400, 400), dropout=0.5): 36 | super(IPNNFrn, self).__init__(field_dims, embed_dim, FRN) 37 | num_fields = len(field_dims) 38 | self.pnn = InnerProductNetwork(num_fields) 39 | 
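# The inner product network yields one scalar per field pair, so with F fields there are
# F * (F - 1) / 2 interaction terms (e.g. 45 for F = 10); they are concatenated with the
# flattened F * E embedding vector before the MLP in forward().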
self.embed_output_dim = num_fields * embed_dim 40 | self.inter_size = num_fields * (num_fields - 1) // 2 41 | self.mlp = MultiLayerPerceptron(self.inter_size + self.input_dim, mlp_layers, dropout=dropout) 42 | 43 | def forward(self, x): 44 | x_emb = self.embedding(x) 45 | x_emb, weight = self.frn(x_emb) 46 | cross_ipnn = self.pnn(x_emb) 47 | 48 | x = torch.cat([x_emb.view(-1, self.embed_output_dim), cross_ipnn], dim=1) 49 | pred_y = self.mlp(x) 50 | return pred_y 51 | 52 | 53 | class OPNN(nn.Module): 54 | def __init__(self, field_dims, embed_dim, mlp_layers=(400, 400, 400), dropout=0.5, kernel_type="vec"): 55 | super(OPNN, self).__init__() 56 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 57 | num_fields = len(field_dims) 58 | if kernel_type == "original": 59 | self.pnn = OuterProductNetwork2(num_fields, embed_dim) 60 | else: 61 | self.pnn = OuterProductNetwork(num_fields, embed_dim, kernel_type) 62 | 63 | self.embed_output_dim = num_fields * embed_dim 64 | self.inter_size = num_fields * (num_fields - 1) // 2 65 | self.mlp = MultiLayerPerceptron(self.inter_size + self.embed_output_dim, mlp_layers, dropout) 66 | 67 | def forward(self, x): 68 | x_emb = self.embedding(x) 69 | cross_opnn = self.pnn(x_emb) 70 | 71 | x = torch.cat([x_emb.view(-1, self.embed_output_dim), cross_opnn], dim=1) 72 | pred_y = self.mlp(x) 73 | return pred_y 74 | 75 | 76 | class OPNNFrn(BasicFRCTR): 77 | def __init__(self, field_dims, embed_dim, FRN=None, mlp_layers=(400, 400, 400), 78 | dropout=0.5, kernel_type="vec"): 79 | super(OPNNFrn, self).__init__(field_dims, embed_dim, FRN) 80 | num_fields = len(field_dims) 81 | if kernel_type == "original": 82 | self.pnn = OuterProductNetwork2(num_fields, embed_dim) 83 | else: 84 | self.pnn = OuterProductNetwork(num_fields, embed_dim, kernel_type) 85 | 86 | self.embed_output_dim = num_fields * embed_dim 87 | self.inter_size = num_fields * (num_fields - 1) // 2 88 | self.mlp = MultiLayerPerceptron(self.inter_size + self.embed_output_dim, mlp_layers, dropout) 89 | 90 | def forward(self, x): 91 | x_emb = self.embedding(x) 92 | x_emb, weight = self.frn(x_emb) 93 | cross_opnn = self.pnn(x_emb) 94 | 95 | x = torch.cat([x_emb.view(-1, self.embed_output_dim), cross_opnn], dim=1) 96 | pred_y = self.mlp(x) 97 | return pred_y 98 | -------------------------------------------------------------------------------- /FRCTR/model_zoo/xdeepfm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch.nn as nn 6 | 7 | from FRCTR.common import CompressedInteractionNetwork, FeaturesEmbedding, \ 8 | FeaturesLinear, MultiLayerPerceptron, BasicFRCTR 9 | 10 | 11 | class CIN(nn.Module): 12 | def __init__(self, field_dims, embed_dim, cross_layer_sizes=(100, 100), split_half=False): 13 | super().__init__() 14 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 15 | self.cin = CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 16 | 17 | def forward(self, x): 18 | """ 19 | param x: Long tensor of size ``(batch_size, num_fields)`` 20 | """ 21 | x_emb = self.embedding(x) 22 | pred_y = self.cin(x_emb) 23 | return pred_y 24 | 25 | 26 | class xDeepFM(nn.Module): 27 | def __init__(self, field_dims, embed_dim, mlp_dims=(400, 400, 400), 28 | dropout=0.5, cross_layer_sizes=(100, 100), split_half=True): 29 | super().__init__() 30 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 31 | self.embed_output_dim = len(field_dims) * embed_dim 32 | self.cin = 
CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 33 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 34 | self.linear = FeaturesLinear(field_dims) 35 | 36 | def forward(self, x): 37 | x_emb = self.embedding(x) 38 | 39 | cin_term = self.cin(x_emb) 40 | mlp_term = self.mlp(x_emb.view(-1, self.embed_output_dim)) 41 | 42 | pred_y = self.linear(x) + cin_term + mlp_term 43 | return pred_y 44 | 45 | 46 | class CINFrn(BasicFRCTR): 47 | def __init__(self, field_dims, embed_dim, FRN=None, 48 | cross_layer_sizes=(100, 100), split_half=False): 49 | super().__init__(field_dims, embed_dim, FRN) 50 | self.cin = CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 51 | 52 | def forward(self, x): 53 | x_emb = self.embedding(x) 54 | x_emb, _ = self.frn(x_emb) 55 | pred_y = self.cin(x_emb) 56 | return pred_y 57 | 58 | 59 | class xDeepFMFrn(BasicFRCTR): 60 | def __init__(self, field_dims, embed_dim, FRN=None, mlp_dims=(400, 400, 400), dropout=0.5, 61 | cross_layer_sizes=(100, 100), split_half=True): 62 | super().__init__(field_dims, embed_dim, FRN) 63 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 64 | self.embed_output_dim = len(field_dims) * embed_dim 65 | self.cin = CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 66 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 67 | self.linear = FeaturesLinear(field_dims) 68 | self.frn = FRN 69 | 70 | def forward(self, x): 71 | x_emb = self.embedding(x) 72 | x_emb, _ = self.frn(x_emb) 73 | cin_term = self.cin(x_emb) 74 | mlp_term = self.mlp(x_emb.view(-1, self.embed_output_dim)) 75 | 76 | pred_y = self.linear(x) + cin_term + mlp_term 77 | return pred_y 78 | 79 | 80 | class xDeepFMFrnP(nn.Module): 81 | def __init__(self, field_dims, embed_dim, FRN1=None, FRN2=None, mlp_dims=(400, 400, 400), dropout=0.5, 82 | cross_layer_sizes=(100, 100), split_half=True): 83 | super().__init__() 84 | self.embedding = FeaturesEmbedding(field_dims, embed_dim) 85 | if not FRN1 or not FRN2: 86 | raise ValueError("Feature Refinement Network is None") 87 | self.frn1 = FRN1 88 | self.frn2 = FRN2 89 | self.embed_output_dim = len(field_dims) * embed_dim 90 | self.cin = CompressedInteractionNetwork(len(field_dims), cross_layer_sizes, split_half) 91 | self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout) 92 | self.linear = FeaturesLinear(field_dims) 93 | 94 | def forward(self, x): 95 | x_emb = self.embedding(x) 96 | x_emb1, _ = self.frn1(x_emb) 97 | x_emb2, _ = self.frn2(x_emb) 98 | cin_term = self.cin(x_emb1) 99 | mlp_term = self.mlp(x_emb2.view(-1, self.embed_output_dim)) 100 | 101 | pred_y = self.linear(x) + cin_term + mlp_term 102 | return pred_y 103 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | import torch 7 | 8 | # import sys 9 | # sys.path.append("../") 10 | 11 | from .skip import Skip 12 | from .contextnet import TCELayer, PFFNLayer 13 | from .dfen import DualFENLayer 14 | from .fen import FENLayer 15 | from .drm import DRMLayer 16 | from .frnet import FRNetLayer 17 | from .fwn import FWNLayer 18 | from .gatenet import GateLayer 19 | from .selfatt import InterCTRLayer 20 | from .senet import SenetLayer 21 | from .gfrllayer import GFRLLayer 22 | 23 | ALLFrn_OPS = { 24 | "skip": lambda field_length, 
embed_dim: Skip(), # Skip-connection 25 | "senet": lambda field_length, embed_dim: SenetLayer(field_length, ratio=2), 26 | "fen": lambda field_length, embed_dim: FENLayer(field_length, embed_dim, mlp_layers=[256, 256, 256]), 27 | "non": lambda field_length, embed_dim: FWNLayer(field_length, embed_dim), 28 | "drm": lambda field_length, embed_dim: DRMLayer(field_length), 29 | "dfen": lambda field_length, embed_dim: DualFENLayer(field_length, embed_dim, att_size=embed_dim, 30 | num_heads=4, embed_dims=[256, 256, 256]), 31 | "gate_v": lambda field_length, embed_dim: GateLayer(field_length, embed_dim, gate_type="vec"), 32 | "gate_b": lambda field_length, embed_dim: GateLayer(field_length, embed_dim, gate_type="bit"), 33 | "pffn": lambda field_length, embed_dim: PFFNLayer(field_length, embed_dim, project_dim=32, num_blocks=3), 34 | "tce": lambda field_length, embed_dim: TCELayer(field_length, embed_dim, project_dim=2*embed_dim), 35 | "gfrl": lambda field_length, embed_dim: GFRLLayer(field_length, embed_dim, dnn_size=[256]), 36 | "frnet_v": lambda field_length, embed_dim: FRNetLayer(field_length, embed_dim, weight_type="vec", 37 | num_layers=1, att_size=16, mlp_layer=128), 38 | "frnet_b": lambda field_length, embed_dim: FRNetLayer(field_length, embed_dim, weight_type="bit", 39 | num_layers=1, att_size=16, mlp_layer=128), 40 | "selfatt": lambda field_length, embed_dim: InterCTRLayer(embed_dim, att_size=16, 41 | num_heads=8, out_dim=embed_dim) 42 | } 43 | 44 | if __name__ == '__main__': 45 | inputs = torch.randn(10, 20, 16) 46 | names = ["skip", "drm","non","senet","fen", "dfen","selfatt", "frnet_b", "frnet_v", "gfrl", "tce", "pffn", "gate_b", "gate_v"] 47 | for index, name in enumerate(names): 48 | frn = ALLFrn_OPS[name](20, 16) 49 | out, weight = frn(inputs) 50 | print("index:{}, frn:{}, size:{}".format(index+1, name, out.size())) 51 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/contextnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class TCELayer(nn.Module): 11 | def __init__(self, field_length, embed_dim, project_dim=32, agg_share=False): 12 | super(TCELayer, self).__init__() 13 | # sharing parameter in aggregation layer 14 | input_dim = field_length * embed_dim 15 | if agg_share: 16 | # individual parameter for each feature(field) 17 | self.aggregation_w = nn.Parameter(torch.randn(1, input_dim, project_dim)) 18 | else: 19 | # share nothing 20 | self.aggregation_w = nn.Parameter(torch.randn(field_length, input_dim, project_dim)) 21 | 22 | self.field_length = field_length 23 | self.project_w = nn.Parameter(torch.randn(field_length, project_dim, embed_dim)) 24 | nn.init.xavier_uniform_(self.project_w.data) 25 | nn.init.xavier_uniform_(self.aggregation_w.data) 26 | 27 | def forward(self, x_emb): 28 | x_cat = x_emb.view(x_emb.size(0), -1).unsqueeze(0).expand(self.field_length, -1, -1) # F,B,F*E 29 | x_agg = torch.relu(torch.matmul(x_cat, self.aggregation_w)) # F, B, P 30 | x_project = torch.matmul(x_agg, self.project_w).permute(1, 0, 2) # FBP FPE = FBE => B,F,E 31 | x_emb = x_emb * torch.relu(x_project) 32 | return x_emb, x_project 33 | 34 | 35 | class PFFNLayer(nn.Module): 36 | def __init__(self, field_length, embed_dim, project_dim=32, 37 | agg_share=False, num_blocks=3): 38 | super(PFFNLayer, self).__init__() 39 | self.tce_layer = 
TCELayer(field_length, embed_dim, project_dim, agg_share=agg_share) 40 | # Do not share any parameter in Point-wise FFN: 41 | self.W1 = nn.Parameter(torch.randn(field_length, embed_dim, embed_dim)) 42 | 43 | # Sharing the parameters 44 | # self.W1 = nn.Parameter(torch.randn(1, embed_dim, embed_dim)) 45 | # self.W2 = nn.Parameter(torch.randn(1, embed_dim, embed_dim)) 46 | self.LN = nn.LayerNorm(embed_dim) 47 | self.num_blocks = num_blocks 48 | nn.init.xavier_uniform_(self.W1.data) 49 | # nn.init.xavier_uniform_(self.W2.data) 50 | 51 | def forward(self, x_emb): 52 | x_emb = self.tce_layer(x_emb)[0] 53 | for _ in range(self.num_blocks): 54 | x_emb_ = torch.matmul(x_emb.permute(1, 0, 2), self.W1) # F,B,E 55 | x_emb = self.LN(x_emb_.permute(1, 0, 2) + x_emb) 56 | # x_emb_ = torch.relu(torch.matmul(x_emb, self.W1)) # F,B,E 57 | # x_emb_ = torch.matmul(x_emb_, self.W2) 58 | # x_emb = self.LN(x_emb_ + x_emb) # ,B,F,E 59 | return x_emb, None 60 | 61 | 62 | class SFFN(nn.Module): 63 | def __init__(self, field_length, embed_dim, project_dim=32, 64 | agg_share=False, num_blocks=3): 65 | super(SFFN, self).__init__() 66 | self.tce_layer = TCELayer(field_length, embed_dim, project_dim, agg_share=agg_share) 67 | self.W1 = nn.Parameter(torch.randn(1, embed_dim, embed_dim)) 68 | self.LN = nn.LayerNorm(embed_dim) 69 | self.num_blocks = num_blocks 70 | nn.init.xavier_uniform_(self.W1.data) 71 | 72 | def forward(self, x_emb): 73 | x_emb = self.tce_layer(x_emb)[0] 74 | for _ in range(self.num_blocks): 75 | x_emb = self.LN(torch.matmul(x_emb, self.W1)) 76 | return x_emb, None 77 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/dfen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class DualFENLayer(nn.Module): 11 | def __init__(self, field_length, embed_dim, embed_dims=[256, 256, 256], att_size=64, num_heads=8): 12 | super(DualFENLayer, self).__init__() 13 | input_dim = field_length * embed_dim # 10*256 14 | self.mlp = MultiLayerPerceptron(input_dim, embed_dims, dropout=0.5, output_layer=False) 15 | 16 | self.multihead = MultiHeadAttentionL(model_dim=embed_dim, dk=att_size, num_heads=num_heads) 17 | self.trans_vec_size = att_size * num_heads * field_length 18 | self.trans_vec = nn.Linear(self.trans_vec_size, field_length, bias=False) 19 | self.trans_bit = nn.Linear(embed_dims[-1], field_length, bias=False) 20 | 21 | def forward(self, x_emb): 22 | # (1) concat 23 | x_con = x_emb.view(x_emb.size(0), -1) # [B, ?] 
24 | 25 | # (2)bit-level difm does not apply softmax or sigmoid 26 | m_bit = self.mlp(x_con) 27 | 28 | # (3)vector-level multi-head 29 | x_att2 = self.multihead(x_emb, x_emb, x_emb) 30 | m_vec = self.trans_vec(x_att2.view(-1, self.trans_vec_size)) 31 | m_bit = self.trans_bit(m_bit) 32 | 33 | x_att = m_bit + m_vec 34 | x_emb = x_emb * x_att.unsqueeze(2) 35 | return x_emb, x_att 36 | 37 | 38 | class MultiHeadAttentionL(nn.Module): 39 | def __init__(self, model_dim=256, dk=32, num_heads=16): 40 | super(MultiHeadAttentionL, self).__init__() 41 | 42 | self.dim_per_head = dk # dk dv 43 | self.num_heads = num_heads 44 | 45 | self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads) 46 | self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads) 47 | self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads) 48 | 49 | self.linear_residual = nn.Linear(model_dim, self.dim_per_head * num_heads) 50 | 51 | # self.layer_norm = nn.LayerNorm(model_dim) # LayerNorm 52 | 53 | def _dot_product_attention(self, q, k, v, scale=None): 54 | attention = torch.bmm(q, k.transpose(1, 2)) * scale 55 | attention = torch.softmax(attention, dim=2) 56 | attention = torch.dropout(attention, p=0.0, train=self.training) 57 | context = torch.bmm(attention, v) 58 | return context, attention 59 | 60 | def forward(self, key0, value0, query0, attn_mask=None): 61 | batch_size = key0.size(0) 62 | 63 | key = self.linear_k(key0) 64 | value = self.linear_v(value0) 65 | query = self.linear_q(query0) 66 | 67 | key = key.view(batch_size * self.num_heads, -1, self.dim_per_head) 68 | value = value.view(batch_size * self.num_heads, -1, self.dim_per_head) 69 | query = query.view(batch_size * self.num_heads, -1, self.dim_per_head) 70 | 71 | scale = (key.size(-1) // self.num_heads) ** -0.5 72 | context, attention = self._dot_product_attention(query, key, value, scale) 73 | context = context.view(batch_size, -1, self.dim_per_head * self.num_heads) 74 | 75 | residual = self.linear_residual(query0) 76 | residual = residual.view(batch_size, -1, self.dim_per_head * self.num_heads) # [B, 10, 256*h] 77 | 78 | output = torch.relu(residual + context) # [B, 10, 256] 79 | return output 80 | 81 | 82 | class MultiLayerPerceptron(nn.Module): 83 | def __init__(self, input_dim, embed_dims, dropout=0.5, output_layer=True): 84 | super().__init__() 85 | layers = list() 86 | for embed_dim in embed_dims: 87 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 88 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 89 | layers.append(torch.nn.ReLU()) 90 | layers.append(torch.nn.Dropout(p=dropout)) 91 | input_dim = embed_dim 92 | 93 | if output_layer: 94 | layers.append(torch.nn.Linear(input_dim, 1)) 95 | self.mlp = torch.nn.Sequential(*layers) 96 | self._init_weight_() 97 | 98 | def _init_weight_(self): 99 | for m in self.mlp: 100 | if isinstance(m, nn.Linear): 101 | nn.init.xavier_uniform_(m.weight) 102 | 103 | def forward(self, x): 104 | return self.mlp(x) 105 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/drm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class DRMLayer(nn.Module): 11 | def __init__(self, field_length, scale=None): 12 | super(DRMLayer, self).__init__() 13 | self.trans_Q = nn.Linear(field_length, field_length) 14 | self.trans_K = nn.Linear(field_length, field_length) 15 | 
self.trans_V = nn.Linear(field_length, field_length) 16 | self.scale = scale 17 | 18 | def _field_attention(self, x_trans): 19 | Q = self.trans_Q(x_trans) 20 | K = self.trans_K(x_trans) 21 | V = self.trans_V(x_trans) 22 | 23 | attention = torch.matmul(Q, K.permute(0, 2, 1)) 24 | if self.scale: 25 | attention = attention * self.scale 26 | attention = F.softmax(attention, dim=-1) 27 | context = torch.matmul(attention, V) 28 | return context, attention 29 | 30 | def forward(self, x_emb): 31 | X_trans = x_emb.permute(0, 2, 1) # B,E,F 32 | X_trans, att_score = self._field_attention(X_trans) 33 | X_trans = X_trans.permute(0, 2, 1) + x_emb # B,F,E 34 | return X_trans.contiguous(), att_score 35 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/fal.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | from torch import nn -------------------------------------------------------------------------------- /FRCTR/module_zoo/fen.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | from torch import nn 7 | 8 | class FENLayer(nn.Module): 9 | def __init__(self, field_length, embed_dim, mlp_layers=[256, 256, 256], h=1): 10 | super(FENLayer, self).__init__() 11 | self.h = h 12 | self.num_fields = field_length 13 | mlp_layers.append(self.num_fields) 14 | self.mlp_input_dim = self.num_fields * embed_dim 15 | self.mlp = MultiLayerPerceptron(self.mlp_input_dim, mlp_layers, dropout=0.5, output_layer=False) 16 | # self.lin_weight = nn.Linear(256,embed_dim,bias=False) 17 | 18 | def forward(self, x_emb): 19 | x_con = x_emb.view(-1, self.mlp_input_dim) # B,F*E 20 | x_con = self.mlp(x_con) # B,1 21 | x_weight = torch.softmax(x_con, dim=1) * self.h # B,F 22 | x_emb_weight = x_emb * x_weight.unsqueeze(2) 23 | return x_emb_weight, x_weight 24 | 25 | class MultiLayerPerceptron(nn.Module): 26 | def __init__(self, input_dim, embed_dims, dropout=0.5, output_layer=True): 27 | super().__init__() 28 | layers = list() 29 | for embed_dim in embed_dims: 30 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 31 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 32 | layers.append(torch.nn.ReLU()) 33 | layers.append(torch.nn.Dropout(p=dropout)) 34 | input_dim = embed_dim 35 | 36 | if output_layer: 37 | layers.append(torch.nn.Linear(input_dim, 1)) 38 | self.mlp = torch.nn.Sequential(*layers) 39 | self._init_weight_() 40 | 41 | def _init_weight_(self): 42 | for m in self.mlp: 43 | if isinstance(m, nn.Linear): 44 | nn.init.xavier_uniform_(m.weight) 45 | 46 | def forward(self, x): 47 | return self.mlp(x) -------------------------------------------------------------------------------- /FRCTR/module_zoo/frnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class FRNetLayer(nn.Module): 11 | def __init__(self, field_length, embed_dim, weight_type="bit", 12 | num_layers=1, att_size=10, mlp_layer=256): 13 | super(FRNetLayer, self).__init__() 14 | self.IEU_G = IEU(field_length, embed_dim, weight_type="bit", 15 | bit_layers=num_layers, att_size=att_size, mlp_layer=mlp_layer) 16 | 17 | # bit-level or vector-level weights. 
18 | self.IEU_W = IEU(field_length, embed_dim, weight_type=weight_type, 19 | bit_layers=num_layers, att_size=att_size, mlp_layer=mlp_layer) 20 | 21 | def forward(self, x_embed): 22 | weight_matrix = torch.sigmoid(self.IEU_W(x_embed)) 23 | com_feature = self.IEU_G(x_embed) 24 | # CSGate 25 | x_out = x_embed * weight_matrix + com_feature * (torch.tensor(1.0) - weight_matrix) 26 | return x_out, weight_matrix 27 | 28 | 29 | class IEU(nn.Module): 30 | def __init__(self, field_length, embed_dim, weight_type="bit", 31 | bit_layers=1, att_size=10, mlp_layer=256): 32 | super(IEU, self).__init__() 33 | self.input_dim = field_length * embed_dim 34 | self.weight_type = weight_type 35 | 36 | # Self-attention unit, which is used to capture cross-feature relationships. 37 | self.vector_info = SelfAttentionIEU(embed_dim=embed_dim, att_size=att_size) 38 | 39 | # contextual information extractor(CIE), FRNet adopt MLP to encode contextual information. 40 | mlp_layers = [mlp_layer for _ in range(bit_layers)] 41 | self.mlps = MultiLayerPerceptronPrelu(self.input_dim, embed_dims=mlp_layers, 42 | output_layer=False) 43 | self.bit_projection = nn.Linear(mlp_layer, embed_dim) 44 | # self.activation = nn.ReLU() 45 | self.activation = nn.PReLU() 46 | 47 | def forward(self, x_emb): 48 | # (1)Self-attetnion unit 49 | x_vector = self.vector_info(x_emb) # B,F,E 50 | 51 | # (2) CIE unit 52 | x_bit = self.mlps(x_emb.view(-1, self.input_dim)) 53 | x_bit = self.bit_projection(x_bit).unsqueeze(1) # B,1,e 54 | x_bit = self.activation(x_bit) 55 | 56 | # (3)integration unit 57 | x_out = x_bit * x_vector 58 | 59 | if self.weight_type == "vector": 60 | # To compute vector-level importance in IEU_W 61 | x_out = torch.sum(x_out, dim=2, keepdim=True) 62 | return x_out # B,F,1 63 | 64 | return x_out # B,F,E 65 | 66 | 67 | class SelfAttentionIEU(nn.Module): 68 | def __init__(self, embed_dim, att_size=20): 69 | super(SelfAttentionIEU, self).__init__() 70 | self.embed_dim = embed_dim 71 | self.trans_Q = nn.Linear(embed_dim, att_size) 72 | self.trans_K = nn.Linear(embed_dim, att_size) 73 | self.trans_V = nn.Linear(embed_dim, att_size) 74 | self.projection = nn.Linear(att_size, embed_dim) 75 | # self.scale = embed_dim.size(-1) ** -0.5 76 | 77 | def forward(self, x, scale=None): 78 | Q = self.trans_Q(x) 79 | K = self.trans_K(x) 80 | V = self.trans_V(x) 81 | 82 | attention = torch.matmul(Q, K.permute(0, 2, 1)) # B,F,F 83 | attention_score = F.softmax(attention, dim=-1) 84 | context = torch.matmul(attention_score, V) 85 | context = self.projection(context) 86 | return context 87 | 88 | 89 | class MultiLayerPerceptronPrelu(torch.nn.Module): 90 | def __init__(self, input_dim, embed_dims, dropout=0.5, output_layer=True): 91 | super().__init__() 92 | layers = list() 93 | for embed_dim in embed_dims: 94 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 95 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 96 | layers.append(torch.nn.PReLU()) 97 | layers.append(torch.nn.Dropout(p=dropout)) 98 | input_dim = embed_dim 99 | 100 | if output_layer: 101 | layers.append(torch.nn.Linear(input_dim, 1)) 102 | self.mlp = torch.nn.Sequential(*layers) 103 | self._init_weight_() 104 | 105 | def _init_weight_(self): 106 | for m in self.mlp: 107 | if isinstance(m, nn.Linear): 108 | nn.init.xavier_uniform_(m.weight) 109 | 110 | def forward(self, x): 111 | return self.mlp(x) 112 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/fwn.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | 8 | from torch import nn 9 | 10 | 11 | class FWNLayer(nn.Module): 12 | # Also known as field-wise network: FWN 13 | def __init__(self, field_length, embed_dim): 14 | super(FWNLayer, self).__init__() 15 | self.input_dim = field_length * embed_dim 16 | self.local_w = nn.Parameter(torch.randn(field_length, embed_dim, embed_dim)) 17 | self.local_b = nn.Parameter(torch.randn(field_length, 1, embed_dim)) 18 | 19 | nn.init.xavier_uniform_(self.local_w.data) 20 | nn.init.xavier_uniform_(self.local_b.data) 21 | 22 | def forward(self, x_emb): 23 | x_local = torch.matmul(x_emb.permute(1, 0, 2), self.local_w) + self.local_b 24 | x_local0 = torch.relu(x_local).permute(1, 0, 2) 25 | x_local = x_local0 * x_emb 26 | return x_local.contiguous(), x_local0 27 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/gatenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class GateLayer(nn.Module): 11 | def __init__(self, field_length, embed_dim, gate_type="vec"): 12 | super(GateLayer, self).__init__() 13 | if gate_type == "bit": 14 | self.local_w = nn.Parameter(torch.randn(field_length, embed_dim, embed_dim)) 15 | elif gate_type == "vec": 16 | self.local_w = nn.Parameter(torch.randn(field_length, embed_dim, 1)) 17 | nn.init.xavier_uniform_(self.local_w.data) 18 | 19 | def forward(self, x_emb): 20 | x_weight = torch.matmul(x_emb.permute(1, 0, 2), self.local_w) 21 | x_weight = x_weight.permute(1, 0, 2) 22 | x_emb_weight = x_weight * x_emb 23 | return x_emb_weight.contiguous(), x_weight 24 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/gfrl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | from torch import nn 7 | 8 | class GFRLLayer(nn.Module): 9 | def __init__(self, field_length, embed_dim, dnn_size=[256]): 10 | super(GFRLLayer, self).__init__() 11 | self.flu1 = FLU(field_length, embed_dim, dnn_size=dnn_size) 12 | self.flu2 = FLU(field_length, embed_dim, dnn_size=dnn_size) 13 | 14 | def forward(self, x_emb): 15 | x_out = self.flu1(x_emb) 16 | x_pro = torch.sigmoid(self.flu2(x_emb)) 17 | x_out = x_emb * (torch.tensor(1.0) - x_pro) + x_out * x_pro 18 | return x_out, x_pro 19 | 20 | 21 | class FLU(nn.Module): 22 | def __init__(self, field_length, embed_dim, dnn_size=[256]): 23 | super(FLU, self).__init__() 24 | self.input_dim = field_length * embed_dim 25 | self.local_w = nn.Parameter(torch.randn(field_length, embed_dim, embed_dim)) 26 | self.local_b = nn.Parameter(torch.randn(field_length, 1, embed_dim)) 27 | 28 | self.mlps = MultiLayerPerceptron(self.input_dim, embed_dims=dnn_size, output_layer=False) 29 | self.bit_info = nn.Linear(dnn_size[-1], embed_dim) 30 | self.acti = nn.ReLU() 31 | 32 | nn.init.xavier_uniform_(self.local_w.data) 33 | nn.init.xavier_uniform_(self.local_b.data) 34 | 35 | def forward(self, x_emb): 36 | x_local = torch.matmul(x_emb.permute(1, 0, 2), self.local_w) + self.local_b 37 | x_local = x_local.permute(1, 0, 2) # B,F,E 38 | 39 | x_glo = self.mlps(x_emb.view(-1, self.input_dim)) 40 | x_glo = self.acti(self.bit_info(x_glo)).unsqueeze(1) # B, 
E 41 | x_out = x_local * x_glo 42 | return x_out 43 | 44 | class MultiLayerPerceptron(nn.Module): 45 | def __init__(self, input_dim, embed_dims, dropout=0.5, output_layer=True): 46 | super().__init__() 47 | layers = list() 48 | for embed_dim in embed_dims: 49 | layers.append(torch.nn.Linear(input_dim, embed_dim)) 50 | layers.append(torch.nn.BatchNorm1d(embed_dim)) 51 | layers.append(torch.nn.ReLU()) 52 | layers.append(torch.nn.Dropout(p=dropout)) 53 | input_dim = embed_dim 54 | 55 | if output_layer: 56 | layers.append(torch.nn.Linear(input_dim, 1)) 57 | self.mlp = torch.nn.Sequential(*layers) 58 | self._init_weight_() 59 | 60 | def _init_weight_(self): 61 | for m in self.mlp: 62 | if isinstance(m, nn.Linear): 63 | nn.init.xavier_uniform_(m.weight) 64 | 65 | def forward(self, x): 66 | return self.mlp(x) -------------------------------------------------------------------------------- /FRCTR/module_zoo/selfatt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class InterCTRLayer(nn.Module): 10 | def __init__(self, embed_dim=16, att_size=32, num_heads=8, out_dim=32, use_res=False): 11 | super(InterCTRLayer, self).__init__() 12 | self.use_res = use_res 13 | self.dim_per_head = att_size 14 | self.num_heads = num_heads 15 | 16 | self.linear_k = nn.Linear(embed_dim, self.dim_per_head * num_heads) 17 | self.linear_v = nn.Linear(embed_dim, self.dim_per_head * num_heads) 18 | self.linear_q = nn.Linear(embed_dim, self.dim_per_head * num_heads) 19 | 20 | self.outputw = torch.nn.Linear(self.dim_per_head * num_heads, out_dim, bias=False) 21 | if self.use_res: 22 | # self.linear_residual = nn.Linear(model_dim, self.dim_per_head * num_heads) 23 | self.linear_residual = nn.Linear(embed_dim, out_dim) 24 | nn.init.xavier_uniform_(self.linear_q.weight) 25 | nn.init.xavier_uniform_(self.linear_k.weight) 26 | nn.init.xavier_uniform_(self.linear_v.weight) 27 | nn.init.xavier_uniform_(self.outputw.weight) 28 | 29 | def _dot_product_attention(self, q, k, v): 30 | attention = torch.bmm(q, k.transpose(1, 2)) 31 | 32 | attention = torch.softmax(attention, dim=2) 33 | 34 | attention = torch.dropout(attention, p=0.0, train=self.training) 35 | context = torch.bmm(attention, v) 36 | return context, attention 37 | 38 | def forward(self, query): 39 | batch_size = query.size(0) 40 | key = self.linear_k(query) 41 | value = self.linear_v(query) 42 | query = self.linear_q(query) 43 | 44 | key = key.view(batch_size * self.num_heads, -1, self.dim_per_head) 45 | value = value.view(batch_size * self.num_heads, -1, self.dim_per_head) 46 | query = query.view(batch_size * self.num_heads, -1, self.dim_per_head) 47 | 48 | context, attention = self._dot_product_attention(query, key, value) # [B*16, 10, 256] 49 | context = context.view(batch_size, -1, self.dim_per_head * self.num_heads) # [B, 10, 256*16] 50 | context = torch.relu(self.outputw(context)) # B, F, out_dim 51 | return context, attention 52 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/senet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch 6 | 7 | from torch import nn 8 | 9 | 10 | class SenetLayer(nn.Module): 11 | def __init__(self, field_length, ratio=1): 12 | super(SenetLayer, self).__init__() 13 | self.temp_dim = max(1, field_length // ratio) 14 | 
self.excitation = nn.Sequential( 15 | nn.Linear(field_length, self.temp_dim), 16 | nn.ReLU(), 17 | nn.Linear(self.temp_dim, field_length), 18 | nn.ReLU() 19 | ) 20 | 21 | def forward(self, x_emb): 22 | """ 23 | (1) Squeeze: max or mean 24 | (2) Excitation 25 | (3) Re-weight 26 | """ 27 | Z_mean = torch.max(x_emb, dim=2, keepdim=True)[0].transpose(1, 2) 28 | # Z_mean = torch.mean(x_emb, dim=2, keepdim=True).transpose(1, 2) 29 | A_weight = self.excitation(Z_mean).transpose(1, 2) 30 | V_embed = torch.mul(A_weight, x_emb) 31 | return V_embed, A_weight 32 | -------------------------------------------------------------------------------- /FRCTR/module_zoo/skip.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project:RefineCTR 4 | """ 5 | import torch.nn as nn 6 | 7 | class Skip(nn.Module): 8 | def forward(self, x_emb): 9 | return x_emb, None -------------------------------------------------------------------------------- /FRCTR/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | from .util import count_params, setup_seed 7 | from .earlystoping import EarlyStopping, EarlyStoppingLoss -------------------------------------------------------------------------------- /FRCTR/utils/auc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | import numpy as np 6 | from sklearn.metrics import roc_auc_score 7 | 8 | 9 | def get_auc(y_labels, y_scores): 10 | auc = roc_auc_score(y_labels, y_scores) 11 | print('AUC calculated by sklearn tool is {}'.format(auc)) 12 | return auc 13 | 14 | 15 | def calculate_auc_func1(y_labels, y_scores): 16 | pos_sample_ids = [i for i in range(len(y_labels)) if y_labels[i] == 1] 17 | neg_sample_ids = [i for i in range(len(y_labels)) if y_labels[i] == 0] 18 | 19 | sum_indicator_value = 0 20 | for i in pos_sample_ids: 21 | for j in neg_sample_ids: 22 | if y_scores[i] > y_scores[j]: 23 | sum_indicator_value += 1 24 | elif y_scores[i] == y_scores[j]: 25 | sum_indicator_value += 0.5 26 | 27 | auc = sum_indicator_value / (len(pos_sample_ids) * len(neg_sample_ids)) 28 | print('AUC calculated by function1 is {:.2f}'.format(auc)) 29 | return auc 30 | 31 | 32 | def calculate_auc_func2(y_labels, y_scores): 33 | samples = list(zip(y_scores, y_labels)) 34 | print(samples) 35 | rank = [(values2, values1) for values1, values2 in sorted(samples, key=lambda x: x[0])] 36 | print(rank) 37 | pos_rank = [i + 1 for i in range(len(rank)) if rank[i][0] == 1] 38 | print(pos_rank) 39 | pos_cnt = np.sum(y_labels == 1) 40 | neg_cnt = np.sum(y_labels == 0) 41 | auc = (np.sum(pos_rank) - pos_cnt * (pos_cnt + 1) / 2) / (pos_cnt * neg_cnt) 42 | print('AUC calculated by function2 is {:.2f}'.format(auc)) 43 | return auc 44 | 45 | 46 | if __name__ == '__main__': 47 | y_labels = np.array([1, 1, 0, 0, 0]) 48 | y_scores = np.array([1, 0.8, 0.2, 0.4, 0.5]) 49 | calculate_auc_func2(y_labels, y_scores) 50 | print(roc_auc_score(y_labels, y_scores)) 51 | -------------------------------------------------------------------------------- /FRCTR/utils/earlystoping.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | ''' 4 | @project:RefineCTR 5 | ''' 6 | 7 | import numpy as np 8 | import torch 9 | 10 | class EarlyStopping: 11 | """Early 
stops the training if the validation AUC doesn't improve after a given patience.""" 12 | 13 | def __init__(self, patience=7, verbose=False, delta=0, prefix=None): 14 | """ 15 | Args: 16 | patience (int): How long to wait after last time validation loss improved. 17 | Default: 7 18 | verbose (bool): If True, prints a message for each validation loss improvement. 19 | Default: False 20 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 21 | Default: 0 22 | """ 23 | self.patience = patience 24 | self.verbose = verbose 25 | self.counter = 0 26 | self.best_score = None 27 | self.early_stop = False 28 | self.val_loss_min = np.Inf 29 | self.delta = delta 30 | self.prefix_path = prefix 31 | 32 | # def __call__(self, val_loss): 33 | def __call__(self, val_auc): 34 | 35 | score = val_auc 36 | 37 | if self.best_score is None: 38 | self.best_score = score 39 | 40 | elif score <= self.best_score + self.delta: 41 | # AUC: no improvement when score <= best; a loss-based criterion would use > 42 | self.counter += 1 43 | print(f'EarlyStopping counter: {self.counter} out of {self.patience}') 44 | print("Now auc:{}\tBest_auc:{}".format(val_auc, self.best_score)) 45 | if self.counter >= self.patience: 46 | self.early_stop = True 47 | else: 48 | self.best_score = score 49 | # self.save_checkpoint(val_loss, model) 50 | self.counter = 0 51 | 52 | def save_checkpoint(self, val_loss, model): 53 | '''Saves the model when the validation loss decreases.''' 54 | if self.verbose: 55 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 56 | torch.save(model.state_dict(), self.prefix_path + '/es_checkpoint.pt') # save the parameters of the best model so far 57 | self.val_loss_min = val_loss 58 | 59 | 60 | class EarlyStoppingLoss: 61 | """Early stops the training if validation loss doesn't improve after a given patience.""" 62 | 63 | def __init__(self, patience=7, verbose=False, delta=0, prefix=None): 64 | """ 65 | Args: 66 | patience (int): How long to wait after last time validation loss improved. 67 | Default: 7 68 | verbose (bool): If True, prints a message for each validation loss improvement. 69 | Default: False 70 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 71 | Default: 0 72 | """ 73 | self.patience = patience 74 | self.verbose = verbose 75 | self.counter = 0 76 | self.best_score = None 77 | self.early_stop = False 78 | self.val_loss_min = np.Inf 79 | self.delta = delta 80 | self.prefix_path = prefix 81 | 82 | def __call__(self, val_loss): 83 | 84 | score = val_loss 85 | 86 | if self.best_score is None: 87 | self.best_score = score 88 | 89 | elif score > self.best_score + self.delta: 90 | self.counter += 1 91 | print(f'EarlyStopping Loss counter: {self.counter} out of {self.patience}') 92 | print("Now loss:{}\tBest_loss:{}".format(val_loss, self.best_score)) 93 | if self.counter >= self.patience: 94 | self.early_stop = True 95 | else: 96 | self.best_score = score 97 | # self.save_checkpoint(val_loss, model) 98 | self.counter = 0 99 | 100 | def save_checkpoint(self, val_loss, model): 101 | ''' Saves the model when the validation loss decreases. ''' 102 | if self.verbose: 103 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).
Saving model ...') 104 | torch.save(model.state_dict(), self.prefix_path + '/es_checkpoint.pt') 105 | self.val_loss_min = val_loss 106 | -------------------------------------------------------------------------------- /FRCTR/utils/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | """ 3 | @project: RefineCTR 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import random 10 | import numpy as np 11 | import os 12 | 13 | 14 | def count_params(model): 15 | params = sum(param.numel() for param in model.parameters()) 16 | return params 17 | 18 | 19 | def setup_seed(seed=2022): 20 | os.environ['PYTHONHASHSEED'] = str(seed) 21 | 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | np.random.seed(seed) 27 | random.seed(seed) 28 | 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.enabled = False 32 | 33 | def get_device(gpu=-1): 34 | if gpu >= 0 and torch.cuda.is_available(): 35 | device = torch.device("cuda:" + str(gpu)) 36 | else: 37 | device = torch.device("cpu") 38 | return device 39 | 40 | def get_optimizer(optimizer, params, lr): 41 | if isinstance(optimizer, str): 42 | if optimizer.lower() == "adam": 43 | optimizer = "Adam" 44 | try: 45 | optimizer = getattr(torch.optim, optimizer)(params, lr=lr) 46 | except: 47 | raise NotImplementedError("optimizer={} is not supported.".format(optimizer)) 48 | return optimizer 49 | 50 | def get_loss_fn(loss): 51 | if isinstance(loss, str): 52 | if loss in ["bce", "binary_crossentropy", "binary_cross_entropy"]: 53 | loss = "binary_cross_entropy" 54 | try: 55 | loss_fn = getattr(torch.functional.F, loss) 56 | except: 57 | try: 58 | from . import losses 59 | loss_fn = getattr(losses, loss) 60 | except: 61 | raise NotImplementedError("loss={} is not supported.".format(loss)) 62 | return loss_fn 63 | 64 | def get_regularizer(reg): 65 | reg_pair = [] # of tuples (p_norm, weight) 66 | if isinstance(reg, float): 67 | reg_pair.append((2, reg)) 68 | elif isinstance(reg, str): 69 | try: 70 | if reg.startswith("l1(") or reg.startswith("l2("): 71 | reg_pair.append((int(reg[1]), float(reg.rstrip(")").split("(")[-1]))) 72 | elif reg.startswith("l1_l2"): 73 | l1_reg, l2_reg = reg.rstrip(")").split("(")[-1].split(",") 74 | reg_pair.append((1, float(l1_reg))) 75 | reg_pair.append((2, float(l2_reg))) 76 | else: 77 | raise NotImplementedError 78 | except: 79 | raise NotImplementedError("regularizer={} is not supported.".format(reg)) 80 | return reg_pair 81 | 82 | def get_activation(activation): 83 | if isinstance(activation, str): 84 | if activation.lower() == "relu": 85 | return nn.ReLU() 86 | elif activation.lower() == "sigmoid": 87 | return nn.Sigmoid() 88 | elif activation.lower() == "tanh": 89 | return nn.Tanh() 90 | else: 91 | return getattr(nn, activation)() 92 | else: 93 | return activation 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |