├── BGNN.py
├── DataHandler.py
├── MV_Net.py
├── Params.py
├── README.md
├── Utils
│   ├── README.md
│   └── TimeLogger.py
├── data
│   └── README.md
├── graph_utils.py
└── main.py
/BGNN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import init
6 | from torch.autograd import Variable
7 | from Params import args
8 |
9 | def to_var(x, requires_grad=True):
10 | if torch.cuda.is_available():
11 | x = x.cuda()
12 | return Variable(x, requires_grad=requires_grad)
13 |
14 |
15 | class myModel(nn.Module):
16 | def __init__(self, userNum, itemNum, behavior, behavior_mats):
17 | super(myModel, self).__init__()
18 |
19 | self.userNum = userNum
20 | self.itemNum = itemNum
21 | self.behavior = behavior
22 | self.behavior_mats = behavior_mats
23 |
24 | self.embedding_dict = self.init_embedding()
25 | self.weight_dict = self.init_weight()
26 | self.gcn = GCN(self.userNum, self.itemNum, self.behavior, self.behavior_mats)
27 |
28 |
29 | def init_embedding(self):
30 |
31 | embedding_dict = {
32 | 'user_embedding': None,
33 | 'item_embedding': None,
34 | 'user_embeddings': None,
35 | 'item_embeddings': None,
36 | }
37 | return embedding_dict
38 |
39 | def init_weight(self):
40 | initializer = nn.init.xavier_uniform_
41 |
42 | weight_dict = nn.ParameterDict({
43 | 'w_self_attention_item': nn.Parameter(initializer(torch.empty([args.hidden_dim, args.hidden_dim]))),
44 | 'w_self_attention_user': nn.Parameter(initializer(torch.empty([args.hidden_dim, args.hidden_dim]))),
45 | 'w_self_attention_cat': nn.Parameter(initializer(torch.empty([args.head_num*args.hidden_dim, args.hidden_dim]))),
46 | 'alpha': nn.Parameter(torch.ones(2)),
47 | })
48 | return weight_dict
49 |
50 |
51 | def forward(self):
52 |
53 | user_embed, item_embed, user_embeds, item_embeds = self.gcn()
54 |
55 |
56 | return user_embed, item_embed, user_embeds, item_embeds
57 |
58 | def para_dict_to_tenser(self, para_dict):
59 |
60 | tensors = []
61 | for beh in para_dict.keys():
62 | tensors.append(para_dict[beh])
63 | tensors = torch.stack(tensors, dim=0)
64 |
65 | return tensors.float()
66 |
67 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
68 | if source_params is not None:
69 | for tgt, src in zip(self.named_parameters(), source_params):
70 | name_t, param_t = tgt
71 | grad = src
72 | if first_order:
73 | grad = to_var(grad.detach().data)
74 | tmp = param_t - lr_inner * grad
75 | self.set_param(self, name_t, tmp)
76 | else:
77 |
78 | for name, param in self.named_parameters():
79 | if not detach:
80 | grad = param.grad
81 | if first_order:
82 | grad = to_var(grad.detach().data)
83 | tmp = param - lr_inner * grad
84 | self.set_param(self, name, tmp)
85 | else:
86 | param = param.detach_()
87 | self.set_param(self, name, param)
88 |
89 |
90 | class GCN(nn.Module):
91 | def __init__(self, userNum, itemNum, behavior, behavior_mats):
92 | super(GCN, self).__init__()
93 | self.userNum = userNum
94 | self.itemNum = itemNum
95 | self.hidden_dim = args.hidden_dim
96 |
97 | self.behavior = behavior
98 | self.behavior_mats = behavior_mats
99 |
100 | self.user_embedding, self.item_embedding = self.init_embedding()
101 |
102 | self.alpha, self.i_concatenation_w, self.u_concatenation_w, self.i_input_w, self.u_input_w = self.init_weight()
103 |
104 | self.sigmoid = torch.nn.Sigmoid()
105 | self.act = torch.nn.PReLU()
106 | self.dropout = torch.nn.Dropout(args.drop_rate)
107 |
108 | self.gnn_layer = eval(args.gnn_layer)
109 | self.layers = nn.ModuleList()
110 | for i in range(0, len(self.gnn_layer)):
111 | self.layers.append(GCNLayer(args.hidden_dim, args.hidden_dim, self.userNum, self.itemNum, self.behavior, self.behavior_mats))
112 |
113 | def init_embedding(self):
114 | user_embedding = torch.nn.Embedding(self.userNum, args.hidden_dim)
115 | item_embedding = torch.nn.Embedding(self.itemNum, args.hidden_dim)
116 | nn.init.xavier_uniform_(user_embedding.weight)
117 | nn.init.xavier_uniform_(item_embedding.weight)
118 |
119 | return user_embedding, item_embedding
120 |
121 | def init_weight(self):
122 | alpha = nn.Parameter(torch.ones(2))
123 | i_concatenation_w = nn.Parameter(torch.Tensor(len(eval(args.gnn_layer))*args.hidden_dim, args.hidden_dim))
124 | u_concatenation_w = nn.Parameter(torch.Tensor(len(eval(args.gnn_layer))*args.hidden_dim, args.hidden_dim))
125 | i_input_w = nn.Parameter(torch.Tensor(args.hidden_dim, args.hidden_dim))
126 | u_input_w = nn.Parameter(torch.Tensor(args.hidden_dim, args.hidden_dim))
127 | init.xavier_uniform_(i_concatenation_w)
128 | init.xavier_uniform_(u_concatenation_w)
129 | init.xavier_uniform_(i_input_w)
130 | init.xavier_uniform_(u_input_w)
131 | # init.xavier_uniform_(alpha)
132 |
133 | return alpha, i_concatenation_w, u_concatenation_w, i_input_w, u_input_w
134 |
135 | def forward(self, user_embedding_input=None, item_embedding_input=None):
136 | all_user_embeddings = []
137 | all_item_embeddings = []
138 | all_user_embeddingss = []
139 | all_item_embeddingss = []
140 |
141 | user_embedding = self.user_embedding.weight
142 | item_embedding = self.item_embedding.weight
143 |
144 | for i, layer in enumerate(self.layers):
145 |
146 | user_embedding, item_embedding, user_embeddings, item_embeddings = layer(user_embedding, item_embedding)
147 |
148 | norm_user_embeddings = F.normalize(user_embedding, p=2, dim=1)
149 | norm_item_embeddings = F.normalize(item_embedding, p=2, dim=1)
150 |
151 | all_user_embeddings.append(user_embedding)
152 | all_item_embeddings.append(item_embedding)
153 | all_user_embeddingss.append(user_embeddings)
154 | all_item_embeddingss.append(item_embeddings)
155 |
156 | user_embedding = torch.cat(all_user_embeddings, dim=1)
157 | item_embedding = torch.cat(all_item_embeddings, dim=1)
158 | user_embeddings = torch.cat(all_user_embeddingss, dim=2)
159 | item_embeddings = torch.cat(all_item_embeddingss, dim=2)
160 |
161 | user_embedding = torch.matmul(user_embedding , self.u_concatenation_w)
162 | item_embedding = torch.matmul(item_embedding , self.i_concatenation_w)
163 | user_embeddings = torch.matmul(user_embeddings , self.u_concatenation_w)
164 | item_embeddings = torch.matmul(item_embeddings , self.i_concatenation_w)
165 |
166 |
167 | return user_embedding, item_embedding, user_embeddings, item_embeddings #[31882, 16], [31882, 16], [4, 31882, 16], [4, 31882, 16]
168 |
169 |
170 | class GCNLayer(nn.Module):
171 | def __init__(self, in_dim, out_dim, userNum, itemNum, behavior, behavior_mats):
172 | super(GCNLayer, self).__init__()
173 |
174 | self.behavior = behavior
175 | self.behavior_mats = behavior_mats
176 |
177 | self.userNum = userNum
178 | self.itemNum = itemNum
179 |
180 | self.act = torch.nn.Sigmoid()
181 | self.i_w = nn.Parameter(torch.Tensor(in_dim, out_dim))
182 | self.u_w = nn.Parameter(torch.Tensor(in_dim, out_dim))
183 | self.ii_w = nn.Parameter(torch.Tensor(in_dim, out_dim))
184 | init.xavier_uniform_(self.i_w)
185 | init.xavier_uniform_(self.u_w)
186 |
187 | def forward(self, user_embedding, item_embedding):
188 |
189 | user_embedding_list = [None]*len(self.behavior)
190 | item_embedding_list = [None]*len(self.behavior)
191 |
192 | for i in range(len(self.behavior)):
193 | user_embedding_list[i] = torch.spmm(self.behavior_mats[i]['A'], item_embedding)
194 | item_embedding_list[i] = torch.spmm(self.behavior_mats[i]['AT'], user_embedding)
195 |
196 |
197 |
198 | user_embeddings = torch.stack(user_embedding_list, dim=0)
199 | item_embeddings = torch.stack(item_embedding_list, dim=0)
200 |
201 |
202 | user_embedding = self.act(torch.matmul(torch.mean(user_embeddings, dim=0), self.u_w))
203 | item_embedding = self.act(torch.matmul(torch.mean(item_embeddings, dim=0), self.i_w))
204 |
205 | user_embeddings = self.act(torch.matmul(user_embeddings, self.u_w))
206 | item_embeddings = self.act(torch.matmul(item_embeddings, self.i_w))
207 |
208 |
209 | return user_embedding, item_embedding, user_embeddings, item_embeddings
210 |
211 | #------------------------------------------------------------------------------------------------------------------------------------------------
212 |
213 |
--------------------------------------------------------------------------------
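A minimal smoke-test sketch for `BGNN.myModel`, assuming the argument defaults from `Params.py` (hidden_dim=16, gnn_layer="[16,16,16]") and a CUDA device, since `graph_utils.get_use` places the sparse behavior matrices on the GPU; the toy shapes and behavior names below are hypothetical:

```python
# Hypothetical smoke test for BGNN.myModel (requires CUDA; sizes and behaviors are made up).
import scipy.sparse as sp

import BGNN
import graph_utils

userNum, itemNum = 100, 200
behaviors = ['view', 'buy']

# one random binary user-item matrix per behavior, keyed by index as in main.py
behavior_mats = {}
for b in range(len(behaviors)):
    mat = sp.random(userNum, itemNum, density=0.05, format='csr')
    behavior_mats[b] = graph_utils.get_use(mat)

model = BGNN.myModel(userNum, itemNum, behaviors, behavior_mats).cuda()
user_embed, item_embed, user_embeds, item_embeds = model()
print(user_embed.shape)    # [userNum, hidden_dim]
print(user_embeds.shape)   # [len(behaviors), userNum, hidden_dim]
```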
/DataHandler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import pickle
4 | import numpy as np
5 | import scipy.sparse as sp
6 | from math import ceil
7 | import datetime
8 |
9 | from Params import args
10 | import graph_utils
11 |
12 |
13 |
14 | class RecDataset(data.Dataset):
15 | def __init__(self, data, num_item, train_mat=None, num_ng=1, is_training=True):
16 | super(RecDataset, self).__init__()
17 |
18 | self.data = np.array(data)
19 | self.num_item = num_item
20 | self.train_mat = train_mat
21 | self.is_training = is_training
22 |
23 | def ng_sample(self):
24 | assert self.is_training, 'no need to sample when testing'
25 | dok_trainMat = self.train_mat.todok()
26 | length = self.data.shape[0]
27 | self.neg_data = np.random.randint(low=0, high=self.num_item, size=length)
28 |
29 | for i in range(length): #
30 | uid = self.data[i][0]
31 | iid = self.neg_data[i]
32 | if (uid, iid) in dok_trainMat:
33 | while (uid, iid) in dok_trainMat:
34 | iid = np.random.randint(low=0, high=self.num_item)
35 | self.neg_data[i] = iid
36 | self.neg_data[i] = iid
37 |
38 | def __len__(self):
39 | return len(self.data)
40 |
41 | def __getitem__(self, idx):
42 | user = self.data[idx][0]
43 | item_i = self.data[idx][1]
44 |
45 | if self.is_training:
46 | neg_data = self.neg_data
47 | item_j = neg_data[idx]
48 | return user, item_i, item_j
49 | else:
50 | return user, item_i
51 |
52 | def getMatrix(self):
53 | pass
54 |
55 | def getAdj(self):
56 | pass
57 |
58 | def sampleLargeGraph(self):
59 |
60 |
61 | def makeMask():
62 | pass
63 |
64 | def updateBdgt():
65 | pass
66 |
67 | def sample():
68 | pass
69 |
70 | def constructData(self):
71 | pass
72 |
73 |
74 |
75 |
76 | class RecDataset_beh(data.Dataset):
77 | def __init__(self, beh, data, num_item, behaviors_data=None, num_ng=1, is_training=True):
78 | super(RecDataset_beh, self).__init__()
79 |
80 | self.data = np.array(data)
81 | self.num_item = num_item
82 | self.is_training = is_training
83 | self.beh = beh
84 | self.behaviors_data = behaviors_data
85 |
86 | self.length = self.data.shape[0]
87 | self.neg_data = [None]*self.length
88 | self.pos_data = [None]*self.length
89 |
90 | def ng_sample(self):
91 | assert self.is_training, 'no need to sample when testing'
92 |
93 | for i in range(self.length):
94 | self.neg_data[i] = [None]*len(self.beh)
95 | self.pos_data[i] = [None]*len(self.beh)
96 |
97 | for index in range(len(self.beh)):
98 |
99 |
100 | train_u, train_v = self.behaviors_data[index].nonzero()
101 | beh_dok = self.behaviors_data[index].todok()
102 |
103 | set_pos = np.array(list(set(train_v)))
104 |
105 | self.pos_data_index = np.random.choice(set_pos, size=self.length, replace=True, p=None)
106 | self.neg_data_index = np.random.randint(low=0, high=self.num_item, size=self.length)
107 |
108 |
109 | for i in range(self.length): #
110 |
111 | uid = self.data[i][0]
112 | iid_neg = self.neg_data[i][index] = self.neg_data_index[i]
113 | iid_pos = self.pos_data[i][index] = self.pos_data_index[i]
114 |
115 | if (uid, iid_neg) in beh_dok:
116 | while (uid, iid_neg) in beh_dok:
117 | iid_neg = np.random.randint(low=0, high=self.num_item)
118 | self.neg_data[i][index] = iid_neg
119 | self.neg_data[i][index] = iid_neg
120 |
121 | if index == (len(self.beh)-1):
122 | self.pos_data[i][index] = train_v[i]
123 | elif (uid, iid_pos) not in beh_dok:
124 | if len(self.behaviors_data[index][uid].data)==0:
125 | self.pos_data[i][index] = -1
126 | else:
127 | t_array = self.behaviors_data[index][uid].toarray()
128 | pos_index = np.where(t_array!=0)[1]
129 | iid_pos = np.random.choice(pos_index, size = 1, replace=True, p=None)[0]
130 | self.pos_data[i][index] = iid_pos
131 |
132 | def __len__(self):
133 | return len(self.data)
134 |
135 | def __getitem__(self, idx):
136 | user = self.data[idx][0]
137 | item_i = self.pos_data[idx]
138 |
139 | if self.is_training:
140 | item_j = self.neg_data[idx]
141 | return user, item_i, item_j
142 | else:
143 | return user, item_i
144 |
145 |
--------------------------------------------------------------------------------
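A small sketch of how `RecDataset_beh` is wired into a `DataLoader`, mirroring the setup in `main.py`; the toy matrices, sizes, and behavior names below are illustrative only:

```python
# Illustrative wiring of RecDataset_beh (toy data; main.py builds the real inputs from pickled matrices).
import numpy as np
import scipy.sparse as sp
import torch.utils.data as dataloader

from DataHandler import RecDataset_beh

behaviors = ['view', 'cart', 'buy']
num_user, num_item = 50, 80
behaviors_data = {i: sp.random(num_user, num_item, density=0.1, format='csr')
                  for i in range(len(behaviors))}

# training pairs come from the target (last) behavior, as in main.py
train_u, train_v = behaviors_data[len(behaviors) - 1].nonzero()
train_data = np.hstack((train_u.reshape(-1, 1), train_v.reshape(-1, 1))).tolist()

train_dataset = RecDataset_beh(behaviors, train_data, num_item, behaviors_data=behaviors_data)
train_loader = dataloader.DataLoader(train_dataset, batch_size=32, shuffle=True)

train_loader.dataset.ng_sample()                 # resample positives/negatives before each epoch
user, item_i, item_j = next(iter(train_loader))  # item_i/item_j hold one index per behavior
```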
/MV_Net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch.autograd import Variable
6 | import torch.nn.init as init
7 |
8 | from torch.nn import init
9 | from Params import args
10 |
11 | import numpy as np
12 |
13 |
14 |
15 | def to_var(x, requires_grad=True):
16 | if torch.cuda.is_available():
17 | x = x.cuda()
18 | return Variable(x, requires_grad=requires_grad)
19 |
20 |
21 | class MetaModule(nn.Module):
22 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE
23 | def params(self):
24 | for name, param in self.named_params(self):
25 | yield param
26 |
27 | def named_leaves(self):
28 | return []
29 |
30 | def named_submodules(self):
31 | return []
32 |
33 | def named_params(self, curr_module=None, memo=None, prefix=''):
34 | if memo is None:
35 | memo = set()
36 |
37 | if hasattr(curr_module, 'named_leaves'):
38 | for name, p in curr_module.named_leaves():
39 | if p is not None and p not in memo:
40 | memo.add(p)
41 | yield prefix + ('.' if prefix else '') + name, p
42 | else:
43 | for name, p in curr_module._parameters.items():
44 | if p is not None and p not in memo:
45 | memo.add(p)
46 | yield prefix + ('.' if prefix else '') + name, p
47 |
48 | for mname, module in curr_module.named_children():
49 | submodule_prefix = prefix + ('.' if prefix else '') + mname
50 | for name, p in self.named_params(module, memo, submodule_prefix):
51 | yield name, p
52 |
53 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
54 | if source_params is not None:
55 | for tgt, src in zip(self.named_params(self), source_params):
56 | name_t, param_t = tgt
57 | # name_s, param_s = src
58 | # grad = param_s.grad
59 | # name_s, param_s = src
60 | grad = src
61 | if first_order:
62 | grad = to_var(grad.detach().data)
63 | tmp = param_t - lr_inner * grad
64 | self.set_param(self, name_t, tmp)
65 | else:
66 |
67 | for name, param in self.named_params(self):
68 | if not detach:
69 | grad = param.grad
70 | if first_order:
71 | grad = to_var(grad.detach().data)
72 | tmp = param - lr_inner * grad
73 | self.set_param(self, name, tmp)
74 | else:
75 | param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686
76 | self.set_param(self, name, param)
77 |
78 | def set_param(self, curr_mod, name, param):
79 | if '.' in name:
80 | n = name.split('.')
81 | module_name = n[0]
82 | rest = '.'.join(n[1:])
83 | for name, mod in curr_mod.named_children():
84 | if module_name == name:
85 | self.set_param(mod, rest, param)
86 | break
87 | else:
88 | setattr(curr_mod, name, param)
89 |
90 | def detach_params(self):
91 | for name, param in self.named_params(self):
92 | self.set_param(self, name, param.detach())
93 |
94 | def copy(self, other, same_var=False):
95 | for name, param in other.named_params():
96 | if not same_var:
97 | param = to_var(param.data.clone(), requires_grad=True)
98 | self.set_param(name, param)
99 |
100 |
101 | class MetaLinear(MetaModule):
102 | def __init__(self, *args, **kwargs):
103 | super().__init__()
104 | ignore = nn.Linear(*args, **kwargs)
105 |
106 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
107 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
108 |
109 | def forward(self, x):
110 | return F.linear(x, self.weight, self.bias)
111 |
112 | def named_leaves(self):
113 | return [('weight', self.weight), ('bias', self.bias)]
114 |
115 |
116 | class MetaConv2d(MetaModule):
117 | def __init__(self, *args, **kwargs):
118 | super().__init__()
119 | ignore = nn.Conv2d(*args, **kwargs)
120 |
121 | self.in_channels = ignore.in_channels
122 | self.out_channels = ignore.out_channels
123 | self.stride = ignore.stride
124 | self.padding = ignore.padding
125 | self.dilation = ignore.dilation
126 | self.groups = ignore.groups
127 | self.kernel_size = ignore.kernel_size
128 |
129 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
130 |
131 | if ignore.bias is not None:
132 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
133 | else:
134 | self.register_buffer('bias', None)
135 |
136 | def forward(self, x):
137 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
138 |
139 | def named_leaves(self):
140 | return [('weight', self.weight), ('bias', self.bias)]
141 |
142 |
143 | class MetaConvTranspose2d(MetaModule):
144 | def __init__(self, *args, **kwargs):
145 | super().__init__()
146 | ignore = nn.ConvTranspose2d(*args, **kwargs)
147 |
148 | self.stride = ignore.stride
149 | self.padding = ignore.padding
150 | self.dilation = ignore.dilation
151 | self.groups = ignore.groups
152 |
153 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
154 |
155 | if ignore.bias is not None:
156 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
157 | else:
158 | self.register_buffer('bias', None)
159 |
160 | def forward(self, x, output_size=None):
161 | output_padding = self._output_padding(x, output_size)
162 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding,
163 | output_padding, self.groups, self.dilation)
164 |
165 | def named_leaves(self):
166 | return [('weight', self.weight), ('bias', self.bias)]
167 |
168 |
169 | class MetaBatchNorm2d(MetaModule):
170 | def __init__(self, *args, **kwargs):
171 | super().__init__()
172 | ignore = nn.BatchNorm2d(*args, **kwargs)
173 |
174 | self.num_features = ignore.num_features
175 | self.eps = ignore.eps
176 | self.momentum = ignore.momentum
177 | self.affine = ignore.affine
178 | self.track_running_stats = ignore.track_running_stats
179 |
180 | if self.affine:
181 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
182 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
183 |
184 | if self.track_running_stats:
185 | self.register_buffer('running_mean', torch.zeros(self.num_features))
186 | self.register_buffer('running_var', torch.ones(self.num_features))
187 | else:
188 | self.register_parameter('running_mean', None)
189 | self.register_parameter('running_var', None)
190 |
191 | def forward(self, x):
192 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
193 | self.training or not self.track_running_stats, self.momentum, self.eps)
194 |
195 | def named_leaves(self):
196 | return [('weight', self.weight), ('bias', self.bias)]
197 |
198 |
199 | def _weights_init(m):
200 | classname = m.__class__.__name__
201 | # print(classname)
202 | if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d):
203 | init.kaiming_normal_(m.weight)
204 |
205 | class LambdaLayer(MetaModule):
206 | def __init__(self, lambd):
207 | super(LambdaLayer, self).__init__()
208 | self.lambd = lambd
209 |
210 | def forward(self, x):
211 | return self.lambd(x)
212 |
213 |
214 | class BasicBlock(MetaModule):
215 | expansion = 1
216 |
217 | def __init__(self, in_planes, planes, stride=1, option='A'):
218 | super(BasicBlock, self).__init__()
219 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
220 | self.bn1 = MetaBatchNorm2d(planes)
221 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
222 | self.bn2 = MetaBatchNorm2d(planes)
223 |
224 | self.shortcut = nn.Sequential()
225 | if stride != 1 or in_planes != planes:
226 | if option == 'A':
227 | """
228 | For CIFAR10 ResNet paper uses option A.
229 | """
230 | self.shortcut = LambdaLayer(lambda x:
231 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
232 | elif option == 'B':
233 | self.shortcut = nn.Sequential(
234 | MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
235 | MetaBatchNorm2d(self.expansion * planes)
236 | )
237 |
238 | def forward(self, x):
239 | out = F.relu(self.bn1(self.conv1(x)))
240 | out = self.bn2(self.conv2(out))
241 | out += self.shortcut(x)
242 | out = F.relu(out)
243 | return out
244 |
245 |
246 | class ResNet32(MetaModule):
247 | def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]):
248 | super(ResNet32, self).__init__()
249 | self.in_planes = 16
250 |
251 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
252 | self.bn1 = MetaBatchNorm2d(16)
253 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
254 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
255 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
256 | self.linear = MetaLinear(64, num_classes)
257 |
258 | self.apply(_weights_init)
259 |
260 | def _make_layer(self, block, planes, num_blocks, stride):
261 | strides = [stride] + [1]*(num_blocks-1)
262 | layers = []
263 | for stride in strides:
264 | layers.append(block(self.in_planes, planes, stride))
265 | self.in_planes = planes * block.expansion
266 |
267 | return nn.Sequential(*layers)
268 |
269 | def forward(self, x):
270 | out = F.relu(self.bn1(self.conv1(x)))
271 | out = self.layer1(out)
272 | out = self.layer2(out)
273 | out = self.layer3(out)
274 | out = F.avg_pool2d(out, out.size()[3])
275 | out = out.view(out.size(0), -1)
276 | out = self.linear(out)
277 | return out
278 |
279 |
280 |
281 | class VNet(MetaModule):
282 | def __init__(self, input, hidden1, output):
283 | super(VNet, self).__init__()
284 | self.linear1 = MetaLinear(input, hidden1)
285 | self.relu1 = nn.ReLU(inplace=True)
286 | self.linear2 = MetaLinear(hidden1, output)
287 | # self.linear3 = MetaLinear(hidden2, output)
288 |
289 | def forward(self, x):
290 | x = self.linear1(x)
291 | x = self.relu1(x)
292 | # x = self.linear2(x)
293 | # x = self.relu1(x)
294 | out = self.linear2(x)
295 | return torch.sigmoid(out)
296 |
297 |
298 | class MetaWeightNet(nn.Module):
299 | def __init__(self, beh_num):
300 | super(MetaWeightNet, self).__init__()
301 |
302 | self.beh_num = beh_num
303 |
304 | self.sigmoid = torch.nn.Sigmoid()
305 | self.act = torch.nn.LeakyReLU(negative_slope=args.slope)
306 | self.prelu = torch.nn.PReLU()
307 | self.relu = torch.nn.ReLU()
308 | self.tanhshrink = torch.nn.Tanhshrink()
309 | self.dropout7 = torch.nn.Dropout(args.drop_rate)
310 | self.dropout5 = torch.nn.Dropout(args.drop_rate1)
311 | self.batch_norm = torch.nn.BatchNorm1d(1)
312 |
313 | initializer = nn.init.xavier_uniform_
314 |
315 |
316 | self.SSL_layer1 = nn.Linear(args.hidden_dim*3, int((args.hidden_dim*3)/2))
317 | self.SSL_layer2 = nn.Linear(int((args.hidden_dim*3)/2), 1)
318 | self.SSL_layer3 = nn.Linear(args.hidden_dim*2, 1)
319 |
320 | self.RS_layer1 = nn.Linear(args.hidden_dim*3, int((args.hidden_dim*3)/2))
321 | self.RS_layer2 = nn.Linear(int((args.hidden_dim*3)/2), 1)
322 | self.RS_layer3 = nn.Linear(args.hidden_dim, 1)
323 |
324 |
325 |
326 | self.beh_embedding = nn.Parameter(initializer(torch.empty([beh_num, args.hidden_dim])))  # registered as a Parameter; moved to GPU together with the module
327 |
328 |
329 | def forward(self, infoNCELoss_list, behavior_loss_multi_list, user_step_index, user_index_list, user_embeds, user_embed):
330 |
331 | infoNCELoss_list_weights = [None]*self.beh_num
332 | behavior_loss_multi_list_weights = [None]*self.beh_num
333 | for i in range(self.beh_num):
334 |
335 | #retailrocket--------------------------------------------------------------------------------------------------------------------------------------------------------------
336 | # SSL_input = args.inner_product_mult*torch.cat((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embeds[i][user_step_index]), 1) #[] [1024, 16]
337 | # SSL_input = args.inner_product_mult*torch.cat((SSL_input, user_embed[user_step_index]), 1)
338 | # SSL_input3 = args.inner_product_mult*((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim*2))*torch.cat((user_embeds[i][user_step_index],user_embed[user_step_index]), 1))
339 |
340 | # infoNCELoss_list_weights[i] = self.dropout7(self.prelu(self.SSL_layer1(SSL_input)))
341 | # infoNCELoss_list_weights[i] = np.sqrt(SSL_input.shape[1])*self.dropout7(self.SSL_layer2(infoNCELoss_list_weights[i]).squeeze())
342 |
343 | # infoNCELoss_list_weights[i] = self.batch_norm(infoNCELoss_list_weights[i].unsqueeze(1)).squeeze()
344 | # infoNCELoss_list_weights[i] = args.inner_product_mult*self.sigmoid(infoNCELoss_list_weights[i])
345 | # SSL_weight3 = self.dropout7(self.prelu(self.SSL_layer3(SSL_input3)))
346 | # SSL_weight3 = self.batch_norm(SSL_weight3).squeeze()
347 |
348 | # SSL_weight3 = args.inner_product_mult*self.sigmoid(SSL_weight3)
349 | # infoNCELoss_list_weights[i] = (infoNCELoss_list_weights[i] + SSL_weight3)/2
350 |
351 | # RS_input = args.inner_product_mult*torch.cat((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embed[user_index_list[i]]), 1)
352 | # RS_input = args.inner_product_mult*torch.cat((RS_input, user_embeds[i][user_index_list[i]]), 1)
353 | # RS_input3 = args.inner_product_mult*((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim))*user_embed[user_index_list[i]])
354 |
355 | # behavior_loss_multi_list_weights[i] = self.dropout7(self.prelu(self.RS_layer1(RS_input)))
356 | # behavior_loss_multi_list_weights[i] = np.sqrt(RS_input.shape[1])*self.dropout7(self.RS_layer2(behavior_loss_multi_list_weights[i]).squeeze())
357 | # behavior_loss_multi_list_weights[i] = self.batch_norm(behavior_loss_multi_list_weights[i].unsqueeze(1))
358 | # behavior_loss_multi_list_weights[i] = args.inner_product_mult*self.sigmoid(behavior_loss_multi_list_weights[i]).squeeze()
359 | # RS_weight3 = self.dropout7(self.prelu(self.RS_layer3(RS_input3)))
360 | # RS_weight3 = self.batch_norm(RS_weight3).squeeze()
361 | # RS_weight3 = args.inner_product_mult*self.sigmoid(RS_weight3).squeeze()
362 | # behavior_loss_multi_list_weights[i] = behavior_loss_multi_list_weights[i] + RS_weight3
363 | #retailrocket--------------------------------------------------------------------------------------------------------------------------------------------------------------
364 |
365 | #IJCAI,Tmall--------------------------------------------------------------------------------------------------------------------------------------------------------------
366 | SSL_input = args.inner_product_mult*torch.cat((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embeds[i][user_step_index]), 1) #[] [1024, 16]
367 | SSL_input = args.inner_product_mult*torch.cat((SSL_input, user_embed[user_step_index]), 1)
368 | SSL_input3 = args.inner_product_mult*((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim*2))*torch.cat((user_embeds[i][user_step_index],user_embed[user_step_index]), 1))
369 |
370 | infoNCELoss_list_weights[i] = self.dropout7(self.prelu(self.SSL_layer1(SSL_input)))
371 | infoNCELoss_list_weights[i] = np.sqrt(SSL_input.shape[1])*self.dropout7(self.SSL_layer2(infoNCELoss_list_weights[i]).squeeze())
372 |
373 | infoNCELoss_list_weights[i] = self.batch_norm(infoNCELoss_list_weights[i].unsqueeze(1)).squeeze()
374 | infoNCELoss_list_weights[i] = args.inner_product_mult*self.sigmoid(infoNCELoss_list_weights[i])
375 | SSL_weight3 = self.dropout7(self.prelu(self.SSL_layer3(SSL_input3)))
376 | SSL_weight3 = self.batch_norm(SSL_weight3).squeeze()
377 |
378 | SSL_weight3 = args.inner_product_mult*self.sigmoid(SSL_weight3)
379 | infoNCELoss_list_weights[i] = (infoNCELoss_list_weights[i] + SSL_weight3)/2
380 |
381 | RS_input = args.inner_product_mult*torch.cat((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embed[user_index_list[i]]), 1)
382 | RS_input = args.inner_product_mult*torch.cat((RS_input, user_embeds[i][user_index_list[i]]), 1)
383 | RS_input3 = args.inner_product_mult*((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim))*user_embed[user_index_list[i]])
384 | behavior_loss_multi_list_weights[i] = self.dropout7(self.prelu(self.RS_layer1(RS_input)))
385 | behavior_loss_multi_list_weights[i] = np.sqrt(RS_input.shape[1])*self.dropout7(self.RS_layer2(behavior_loss_multi_list_weights[i]).squeeze())
386 | behavior_loss_multi_list_weights[i] = self.batch_norm(behavior_loss_multi_list_weights[i].unsqueeze(1))
387 | behavior_loss_multi_list_weights[i] = args.inner_product_mult*self.sigmoid(behavior_loss_multi_list_weights[i]).squeeze()
388 | RS_weight3 = self.dropout7(self.prelu(self.RS_layer3(RS_input3)))
389 | RS_weight3 = self.batch_norm(RS_weight3).squeeze()
390 | RS_weight3 = args.inner_product_mult*self.sigmoid(RS_weight3).squeeze()
391 | behavior_loss_multi_list_weights[i] = behavior_loss_multi_list_weights[i] + RS_weight3
392 |
393 |
394 | return infoNCELoss_list_weights, behavior_loss_multi_list_weights
395 |
396 |
397 |
398 |
--------------------------------------------------------------------------------
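A hedged sketch of the meta ("virtual") parameter update that `MetaModule`/`VNet` enable: gradients are taken with respect to the meta-parameters using `create_graph=True` and applied through `update_params`, so the inner step itself stays differentiable. The feature size and the stand-in loss below are made up:

```python
# Sketch of a differentiable inner update on VNet (hypothetical feature size and loss).
import torch

from MV_Net import VNet

vnet = VNet(4, 8, 1)                       # maps a 4-d loss feature to a weight in (0, 1)
feats = torch.randn(16, 4)
if torch.cuda.is_available():
    feats = feats.cuda()                   # MetaLinear buffers live on the GPU when CUDA is available

weights = vnet(feats)                      # [16, 1]
pseudo_loss = weights.mean()               # stand-in for the weighted recommendation loss

# differentiate w.r.t. the meta-parameters and take one SGD-style inner step
grads = torch.autograd.grad(pseudo_loss, list(vnet.params()), create_graph=True)
vnet.update_params(lr_inner=1e-3, source_params=grads)
```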
/Params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def parse_args():
5 | parser = argparse.ArgumentParser(description='Model Params')
6 |
7 |
8 | #for this model
9 | parser.add_argument('--hidden_dim', default=16, type=int, help='embedding size')
10 | parser.add_argument('--gnn_layer', default="[16,16,16]", type=str, help='gnn layers: number + dim')
11 | parser.add_argument('--dataset', default='IJCAI_15', type=str, help='name of dataset')
12 | parser.add_argument('--point', default='for_meta_hidden_dim', type=str, help='')
13 | parser.add_argument('--title', default='dim__8', type=str, help='title of model')
14 | parser.add_argument('--sampNum', default=10, type=int, help='batch size for sampling')
15 |
16 | #for train
17 | parser.add_argument('--lr', default=3e-4, type=float, help='learning rate')
18 | parser.add_argument('--opt_base_lr', default=1e-3, type=float, help='learning rate')
19 | parser.add_argument('--opt_max_lr', default=2e-3, type=float, help='learning rate')
20 | parser.add_argument('--opt_weight_decay', default=1e-4, type=float, help='weight decay regularizer')
21 | parser.add_argument('--meta_opt_base_lr', default=1e-4, type=float, help='learning rate')
22 | parser.add_argument('--meta_opt_max_lr', default=1e-3, type=float, help='learning rate')
23 | parser.add_argument('--meta_opt_weight_decay', default=1e-4, type=float, help='weight decay regularizer')
24 | parser.add_argument('--meta_lr', default=1e-3, type=float, help='meta learning rate')
25 |
26 | parser.add_argument('--batch', default=8192, type=int, help='batch size')
27 | parser.add_argument('--meta_batch', default=128, type=int, help='batch size')
28 | parser.add_argument('--SSL_batch', default=30, type=int, help='batch size')
29 | parser.add_argument('--reg', default=1e-3, type=float, help='weight decay regularizer')
30 | parser.add_argument('--beta', default=0.005, type=float, help='scale of infoNCELoss')
31 | parser.add_argument('--epoch', default=300, type=int, help='number of epochs')
32 | # parser.add_argument('--decay', default=0.96, type=float, help='weight decay rate')
33 | parser.add_argument('--shoot', default=10, type=int, help='K of top k')
34 | parser.add_argument('--inner_product_mult', default=1, type=float, help='multiplier for the result')
35 | parser.add_argument('--drop_rate', default=0.8, type=float, help='drop_rate')
36 | parser.add_argument('--drop_rate1', default=0.5, type=float, help='drop_rate')
37 | parser.add_argument('--seed', type=int, default=6)
38 | parser.add_argument('--slope', type=float, default=0.1)
39 | parser.add_argument('--patience', type=int, default=100)
40 | #for save and read
41 | parser.add_argument('--path', default='/home/ww/Code/MultiBehavior_BASELINE/MB-GCN/Datasets/', type=str, help='data path')
42 | parser.add_argument('--save_path', default='tem', help='file name to save model and training record')
43 | parser.add_argument('--load_model', default=None, help='model name to load')
44 | parser.add_argument('--target', default='buy', type=str, help='target behavior to predict on')
45 | parser.add_argument('--isload', default=False, type=bool, help='whether load model')
46 | parser.add_argument('--isJustTest', default=False, type=bool, help='whether just test the loaded model')
47 | parser.add_argument('--loadModelPath', default='/home/ww/Code/work3/BSTRec/Model/IJCAI_15/for_meta_hidden_dim_dim__8_IJCAI_15_2021_07_10__14_11_55_lr_0.0003_reg_0.001_batch_size_4096_gnn_layer_[16,16,16].pth', type=str, help='loadModelPath')
48 | # #Tmall: # loadPath_SSL_meta = "/home/ww/Code/work3/BSTRec/Model/Tmall/for_meta_hidden_dim_dim__8_Tmall_2021_07_08__01_35_54_lr_0.0003_reg_0.001_batch_size_4096_gnn_layer_[16,16,16].pth"
49 | # #IJCAI_15: # loadPath_SSL_meta = "/home/ww/Code/work3/BSTRec/Model/IJCAI_15/for_meta_hidden_dim_dim__8_IJCAI_15_2021_07_10__14_11_55_lr_0.0003_reg_0.001_batch_size_4096_gnn_layer_[16,16,16].pth"
50 | # #retailrocket: # loadPath_SSL_meta = "/home/ww/Code/work3/BSTRec/Model/retailrocket/for_meta_hidden_dim_dim__8_retailrocket_2021_07_10__18_35_32_lr_0.0003_reg_0.01_batch_size_1024_gnn_layer_[16,16,16].pth"
51 |
52 |
53 |
54 | #use less
55 | # parser.add_argument('--memosize', default=2, type=int, help='memory size')
56 | parser.add_argument('--head_num', default=4, type=int, help='head_num_of_multihead_attention')
57 | parser.add_argument('--beta_multi_behavior', default=0.005, type=float, help='scale of infoNCELoss')
58 | parser.add_argument('--sampNum_slot', default=30, type=int, help='SSL_step')
59 | parser.add_argument('--SSL_slot', default=1, type=int, help='SSL_step')
60 | parser.add_argument('--k', default=2, type=float, help='MFB')
61 | parser.add_argument('--meta_time_rate', default=0.8, type=float, help='gating rate')
62 | parser.add_argument('--meta_behavior_rate', default=0.8, type=float, help='gating rate')
63 | parser.add_argument('--meta_slot', default=2, type=int, help='epoch number for each SSL')
64 | parser.add_argument('--time_slot', default=60*60*24*360, type=float, help='length of time slots')
65 | parser.add_argument('--hidden_dim_meta', default=16, type=int, help='embedding size')
66 | # parser.add_argument('--att_head', default=2, type=int, help='number of attention heads')
67 | # parser.add_argument('--gnn_layer', default=2, type=int, help='number of gnn layers')
68 | # parser.add_argument('--trnNum', default=10000, type=int, help='number of training instances per epoch')
69 | # parser.add_argument('--deep_layer', default=0, type=int, help='number of deep layers to make the final prediction')
70 | # parser.add_argument('--iiweight', default=0.3, type=float, help='weight for ii')
71 | # parser.add_argument('--graphSampleN', default=10000, type=int, help='use 25000 for training and 200000 for testing, empirically')
72 | # parser.add_argument('--divSize', default=1000, type=int, help='div size for smallTestEpoch')
73 | # parser.add_argument('--tstEpoch', default=1, type=int, help='number of epoch to test while training')
74 | # parser.add_argument('--subUsrSize', default=10, type=int, help='number of item for each sub-user')
75 | # parser.add_argument('--subUsrDcy', default=0.9, type=float, help='decay factor for sub-users over time')
76 | # parser.add_argument('--slot', default=0.5, type=float, help='length of time slots')
77 |
78 |
79 | return parser.parse_args()
80 | args = parse_args()
81 |
82 | #
83 | # args.user = 805506#147894
84 | # args.item = 584050#99037
85 | # ML10M
86 | # args.user = 67788
87 | # args.item = 8704
88 | # yelp
89 | # args.user = 19800
90 | # args.item = 22734
91 |
92 | # swap user and item
93 | # tem = args.user
94 | # args.user = args.item
95 | # args.item = tem
96 |
97 | # args.decay_step = args.trn_num
98 | # args.decay_step = args.item
99 | # args.decay_step = args.trnNum
100 |
--------------------------------------------------------------------------------
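Other modules simply do `from Params import args`, so every hyperparameter above can be overridden from the command line. A quick sanity check, assuming the argument definitions above are active and the script is run without extra flags (values shown are the defaults):

```python
# Quick check of the parsed defaults (run as a script so argparse sees no unknown flags).
from Params import args

print(args.dataset, args.hidden_dim, args.gnn_layer)   # IJCAI_15 16 [16,16,16]
print(args.lr, args.batch, args.target)                # 0.0003 8192 buy
```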
/README.md:
--------------------------------------------------------------------------------
1 | # CML
2 |
3 |
4 |
5 | This repository contains the PyTorch code and datasets for the paper:
6 |
7 | > Wei, Wei and Huang, Chao and Xia, Lianghao and Xu, Yong and Zhao, Jiashu and Yin, Dawei. Contrastive Meta Learning with Behavior Multiplicity for Recommendation. Paper on arXiv.
8 |
9 |
10 | ## Introduction
11 | Contrastive Meta Learning (CML) leverages a multi-behavior learning paradigm to model diverse and multiplex user-item relationships and to tackle the label scarcity problem for target behaviors. The designed multi-behavior contrastive task captures transferable user-item relationships from heterogeneous, multi-typed user behavior data. The proposed meta contrastive encoding scheme allows CML to preserve personalized multi-behavior characteristics, so as to reflect diverse behavior-aware user preferences under a customized self-supervised framework.
12 |
13 |
14 | ## Citation
15 | ```
16 | @inproceedings{wei2022contrastive,
17 | title={Contrastive meta learning with behavior multiplicity for recommendation},
18 | author={Wei, Wei and Huang, Chao and Xia, Lianghao and Xu, Yong and Zhao, Jiashu and Yin, Dawei},
19 | booktitle={Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining},
20 | pages={1120--1128},
21 | year={2022}
22 | }
23 | ```
24 |
25 |
26 | ## Environment
27 |
28 | The code of CML is implemented and tested under the following development environment:
29 |
30 | - Python 3.6
31 | - torch==1.8.1+cu111
32 | - scipy==1.6.2
33 | - tqdm==4.61.2
34 |
35 |
36 |
37 | ## Datasets
38 |
39 | #### Raw data:
40 | - IJCAI contest: https://tianchi.aliyun.com/dataset/dataDetail?dataId=47
41 | - Retail Rocket: https://www.kaggle.com/retailrocket/ecommerce-dataset
42 | - Tmall: https://tianchi.aliyun.com/dataset/dataDetail?dataId=649
43 | #### Processed data:
44 | - The processed IJCAI data is under the /datasets folder.
45 |
46 |
47 | ## Usage
48 |
49 | The commands to train CML on the Tmall/IJCAI/Retailrocket datasets are as follows. They specify the hyperparameter settings that generate the results reported in the paper.
50 |
51 | * Tmall
52 | ```
53 | python main.py --path=./datasets/ --dataset=Tmall --opt_base_lr=1e-3 --opt_max_lr=5e-3 --opt_weight_decay=1e-4 --meta_opt_base_lr=1e-4 --meta_opt_max_lr=2e-3 --meta_opt_weight_decay=1e-4 --meta_lr=1e-3 --batch=8192 --meta_batch=128 --SSL_batch=18
54 | ```
55 | * IJCAI
56 | ```
57 | python main.py --path=./datasets/ --dataset=IJCAI_15 --sampNum=10 --opt_base_lr=1e-3 --opt_max_lr=2e-3 --opt_weight_decay=1e-4 --meta_opt_base_lr=1e-4 --meta_opt_max_lr=1e-3 --meta_opt_weight_decay=1e-4 --meta_lr=1e-3 --batch=8192 --meta_batch=128 --SSL_batch=30
58 | ```
59 | * Retailrocket
60 | ```
61 | python main.py --path=./datasets/ --dataset='retailrocket' --sampNum=40 --lr=3e-4 --opt_base_lr=1e-4 --opt_max_lr=1e-3 --opt_weight_decay=1e-4 --meta_opt_base_lr=1e-4 --meta_opt_max_lr=1e-3 --meta_opt_weight_decay=1e-3 --meta_lr=1e-3 --batch=2048 --meta_batch=128 --SSL_batch=15
62 | ```
63 |
64 |
65 |
66 |
72 |
73 |
74 |
75 |
76 |
77 |
79 |
80 |
81 |
82 |
83 |
84 | The code will be released again in a few days in an optimized version.
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
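Relating to the Usage section of the README above: to evaluate a saved checkpoint without retraining, the load flags defined in Params.py (--isload, --isJustTest, --loadModelPath) should work along these lines; the checkpoint path below is illustrative only:

```
python main.py --path=./datasets/ --dataset=Tmall --isload=True --isJustTest=True --loadModelPath=./Model/Tmall/your_checkpoint.pth
```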
/Utils/README.md:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/Utils/TimeLogger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | logmsg = ''
4 | timemark = dict()
5 | saveDefault = False
6 | def log(msg, save=None, oneline=False):
7 | global logmsg
8 | global saveDefault
9 | time = datetime.datetime.now()
10 | tem = '%s: %s' % (time, msg)
11 | if save != None:
12 | if save:
13 | logmsg += tem + '\n'
14 | elif saveDefault:
15 | logmsg += tem + '\n'
16 | if oneline:
17 | print(tem, end='\r')
18 | else:
19 | print(tem)
20 |
21 | def marktime(marker):
22 | global timemark
23 | timemark[marker] = datetime.datetime.now()
24 |
25 | def SpentTime(marker):
26 | global timemark
27 | if marker not in timemark:
28 | msg = 'LOGGER ERROR, marker %s not found' % marker
29 | tem = '%s: %s' % (datetime.datetime.now(), msg)
30 | print(tem)
31 | return False
32 | return datetime.datetime.now() - timemark[marker]
33 |
34 | def SpentTooLong(marker, day=0, hour=0, minute=0, second=0):
35 | global timemark
36 | if marker not in timemark:
37 | msg = 'LOGGER ERROR, marker %s not found' % marker
38 | tem = '%s: %s' % (datetime.datetime.now(), msg)
39 | print(tem)
40 | return False
41 | return datetime.datetime.now() - timemark[marker] >= datetime.timedelta(days=day, hours=hour, minutes=minute, seconds=second)
42 |
43 | if __name__ == '__main__':
44 | log('')
--------------------------------------------------------------------------------
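Typical usage of the logger helpers (the marker name is illustrative):

```python
# Minimal usage sketch for Utils/TimeLogger (marker name is made up).
from Utils import TimeLogger as logger

logger.saveDefault = True               # keep every message in logger.logmsg
logger.marktime('train')
logger.log('start training')
# ... training work ...
logger.log('epoch done', oneline=True)  # overwrite the same console line
print(logger.SpentTime('train'))        # timedelta since marktime('train')
if logger.SpentTooLong('train', hour=12):
    logger.log('training has exceeded 12 hours')
```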
/data/README.md:
--------------------------------------------------------------------------------
1 | .
2 |
--------------------------------------------------------------------------------
/graph_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | from scipy.sparse import *
4 | import torch
5 |
6 | from Params import args
7 |
8 |
9 | def get_use(behaviors_data):
10 |
11 | behavior_mats = {}
12 |
13 | behaviors_data = (behaviors_data != 0) * 1
14 |
15 | behavior_mats['A'] = matrix_to_tensor(normalize_adj(behaviors_data))
16 | behavior_mats['AT'] = matrix_to_tensor(normalize_adj(behaviors_data.T))
17 | behavior_mats['A_ori'] = None
18 |
19 | return behavior_mats
20 |
21 |
22 | def normalize_adj(adj):
23 | """Symmetrically normalize adjacency matrix."""
24 | adj = sp.coo_matrix(adj)
25 | rowsum = np.array(adj.sum(1))
26 | rowsum_diag = sp.diags(np.power(rowsum+1e-8, -0.5).flatten())
27 |
28 | colsum = np.array(adj.sum(0))
29 | colsum_diag = sp.diags(np.power(colsum+1e-8, -0.5).flatten())
30 |
31 |
32 | return rowsum_diag.dot(adj).dot(colsum_diag)  # D_r^-1/2 * A * D_c^-1/2
33 |
34 |
35 | def matrix_to_tensor(cur_matrix):
36 | if type(cur_matrix) != sp.coo_matrix:
37 | cur_matrix = cur_matrix.tocoo()
38 | indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64))
39 | values = torch.from_numpy(cur_matrix.data)
40 | shape = torch.Size(cur_matrix.shape)
41 |
42 | return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda()
--------------------------------------------------------------------------------
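A toy check of `get_use` (CUDA is required because `matrix_to_tensor` moves the sparse tensors to the GPU); the interaction matrix below is random and purely illustrative:

```python
# Toy usage of graph_utils.get_use (requires CUDA; random interactions stand in for real data).
import scipy.sparse as sp
import torch

import graph_utils

inter = sp.random(100, 200, density=0.05, format='csr')   # 100 users x 200 items
mats = graph_utils.get_use(inter)

print(mats['A'].shape, mats['AT'].shape)                   # [100, 200], [200, 100]

item_emb = torch.randn(200, 16).cuda()
user_msg = torch.spmm(mats['A'], item_emb)                 # user-side aggregation as in GCNLayer, [100, 16]
```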
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy import random
3 | import pickle
4 | from scipy.sparse import csr_matrix
5 | import math
6 | import gc
7 | import time
8 | import random
9 | import datetime
10 |
11 | import torch as t
12 | import torch.nn as nn
13 | import torch.utils.data as dataloader
14 | import torch.nn.functional as F
15 | from torch.nn import init
16 |
17 | import graph_utils
18 | import DataHandler
19 |
20 |
21 | import BGNN
22 | import MV_Net
23 |
24 | from Params import args
25 | from Utils.TimeLogger import log
26 | from tqdm import tqdm
27 |
28 | t.backends.cudnn.benchmark=True
29 |
30 | if t.cuda.is_available():
31 | use_cuda = True
32 | else:
33 | use_cuda = False
34 |
35 | MAX_FLAG = 0x7FFFFFFF
36 |
37 | now_time = datetime.datetime.now()
38 | modelTime = datetime.datetime.strftime(now_time,'%Y_%m_%d__%H_%M_%S')
39 |
40 | t.autograd.set_detect_anomaly(True)
41 |
42 | class Model():
43 | def __init__(self):
44 |
45 |
46 | self.trn_file = args.path + args.dataset + '/trn_'
47 | self.tst_file = args.path + args.dataset + '/tst_int'
48 | # self.tst_file = args.path + args.dataset + '/BST_tst_int_59'
49 | #Tmall: 3,4,5,6,8,59
50 | #IJCAI_15: 5,6,8,10,13,53
51 |
52 | # self.meta_multi_file = args.path + args.dataset + '/meta_multi_beh_user_index'
53 | # self.meta_single_file = args.path + args.dataset + '/meta_single_beh_user_index'
54 | self.meta_multi_single_file = args.path + args.dataset + '/meta_multi_single_beh_user_index_shuffle'
55 | # /meta_multi_single_beh_user_index_shuffle
56 | # /new_multi_single
57 |
58 | # self.meta_multi = pickle.load(open(self.meta_multi_file, 'rb'))
59 | # self.meta_single = pickle.load(open(self.meta_single_file, 'rb'))
60 | self.meta_multi_single = pickle.load(open(self.meta_multi_single_file, 'rb'))
61 |
62 | self.t_max = -1
63 | self.t_min = 0x7FFFFFFF
64 | self.time_number = -1
65 |
66 | self.user_num = -1
67 | self.item_num = -1
68 | self.behavior_mats = {}
69 | self.behaviors = []
70 | self.behaviors_data = {}
71 |
72 | #history
73 | self.train_loss = []
74 | self.his_hr = []
75 | self.his_ndcg = []
76 | gc.collect() #
77 |
78 | self.relu = t.nn.ReLU()
79 | self.sigmoid = t.nn.Sigmoid()
80 | self.curEpoch = 0
81 |
82 |
83 | if args.dataset == 'Tmall':
84 | self.behaviors_SSL = ['pv','fav', 'cart', 'buy']
85 | self.behaviors = ['pv','fav', 'cart', 'buy']
86 | # self.behaviors = ['buy']
87 | elif args.dataset == 'IJCAI_15':
88 | self.behaviors = ['click','fav', 'cart', 'buy']
89 | # self.behaviors = ['buy']
90 | self.behaviors_SSL = ['click','fav', 'cart', 'buy']
91 |
92 | elif args.dataset == 'JD':
93 | self.behaviors = ['review','browse', 'buy']
94 | self.behaviors_SSL = ['review','browse', 'buy']
95 |
96 | elif args.dataset == 'retailrocket':
97 | self.behaviors = ['view','cart', 'buy']
98 | # self.behaviors = ['buy']
99 | self.behaviors_SSL = ['view','cart', 'buy']
100 |
101 |
102 | for i in range(0, len(self.behaviors)):
103 | with open(self.trn_file + self.behaviors[i], 'rb') as fs:
104 | data = pickle.load(fs)
105 | self.behaviors_data[i] = data
106 |
107 | if data.get_shape()[0] > self.user_num:
108 | self.user_num = data.get_shape()[0]
109 | if data.get_shape()[1] > self.item_num:
110 | self.item_num = data.get_shape()[1]
111 |
112 |
113 | if data.data.max() > self.t_max:
114 | self.t_max = data.data.max()
115 | if data.data.min() < self.t_min:
116 | self.t_min = data.data.min()
117 |
118 |
119 | if self.behaviors[i]==args.target:
120 | self.trainMat = data
121 | self.trainLabel = 1*(self.trainMat != 0)
122 | self.labelP = np.squeeze(np.array(np.sum(self.trainLabel, axis=0)))
123 |
124 |
125 | time = datetime.datetime.now()
126 | print("Start building: ", time)
127 | for i in range(0, len(self.behaviors)):
128 | self.behavior_mats[i] = graph_utils.get_use(self.behaviors_data[i])
129 | time = datetime.datetime.now()
130 | print("End building:", time)
131 |
132 |
133 | print("user_num: ", self.user_num)
134 | print("item_num: ", self.item_num)
135 | print("\n")
136 |
137 |
138 | #---------------------------------------------------------------------------------------------->>>>>
139 | #train_data
140 | train_u, train_v = self.trainMat.nonzero()
141 | train_data = np.hstack((train_u.reshape(-1,1), train_v.reshape(-1,1))).tolist()
142 | train_dataset = DataHandler.RecDataset_beh(self.behaviors, train_data, self.item_num, self.behaviors_data, True)
143 | self.train_loader = dataloader.DataLoader(train_dataset, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True)
144 |
145 | #valid_data
146 |
147 |
148 | # test_data
149 | with open(self.tst_file, 'rb') as fs:
150 | data = pickle.load(fs)
151 |
152 | test_user = np.array([idx for idx, i in enumerate(data) if i is not None])
153 | test_item = np.array([i for idx, i in enumerate(data) if i is not None])
154 | # tstUsrs = np.reshape(np.argwhere(data!=None), [-1])
155 | test_data = np.hstack((test_user.reshape(-1,1), test_item.reshape(-1,1))).tolist()
156 | # testbatch = np.maximum(1, args.batch * args.sampNum
157 | test_dataset = DataHandler.RecDataset(test_data, self.item_num, self.trainMat, 0, False)
158 | self.test_loader = dataloader.DataLoader(test_dataset, batch_size=args.batch, shuffle=False, num_workers=4, pin_memory=True)
159 | # -------------------------------------------------------------------------------------------------->>>>>
160 |
161 | def prepareModel(self):
162 | self.modelName = self.getModelName()
163 | # self.setRandomSeed()
164 | self.gnn_layer = eval(args.gnn_layer)
165 | self.hidden_dim = args.hidden_dim
166 |
167 |
168 | if args.isload == True:
169 | self.loadModel(args.loadModelPath)
170 | else:
171 | self.model = BGNN.myModel(self.user_num, self.item_num, self.behaviors, self.behavior_mats).cuda()
172 | self.meta_weight_net = MV_Net.MetaWeightNet(len(self.behaviors)).cuda()
173 |
174 |
175 |
176 | # #IJCAI_15
177 | # self.opt = t.optim.AdamW(self.model.parameters(), lr = args.lr, weight_decay = args.opt_weight_decay)
178 | # self.meta_opt = t.optim.AdamW(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay)
179 | # # self.meta_opt = t.optim.RMSprop(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay, momentum=0.95, centered=True)
180 | # self.scheduler = t.optim.lr_scheduler.CyclicLR(self.opt, args.opt_base_lr, args.opt_max_lr, step_size_up=5, step_size_down=10, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
181 | # self.meta_scheduler = t.optim.lr_scheduler.CyclicLR(self.meta_opt, args.meta_opt_base_lr, args.meta_opt_max_lr, step_size_up=2, step_size_down=3, mode='triangular', gamma=0.98, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.9, max_momentum=0.99, last_epoch=-1)
182 | # #
183 |
184 |
185 | #Tmall
186 | self.opt = t.optim.AdamW(self.model.parameters(), lr = args.lr, weight_decay = args.opt_weight_decay)
187 | self.meta_opt = t.optim.AdamW(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay)
188 | # self.meta_opt = t.optim.RMSprop(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay, momentum=0.95, centered=True)
189 | self.scheduler = t.optim.lr_scheduler.CyclicLR(self.opt, args.opt_base_lr, args.opt_max_lr, step_size_up=5, step_size_down=10, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
190 | self.meta_scheduler = t.optim.lr_scheduler.CyclicLR(self.meta_opt, args.meta_opt_base_lr, args.meta_opt_max_lr, step_size_up=3, step_size_down=7, mode='triangular', gamma=0.98, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.9, max_momentum=0.99, last_epoch=-1)
191 | # 0.993
192 |
193 | # # retailrocket
194 | # self.opt = t.optim.AdamW(self.model.parameters(), lr = args.lr, weight_decay = args.opt_weight_decay)
195 | # # self.meta_opt = t.optim.AdamW(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay)
196 | # self.meta_opt = t.optim.SGD(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay, momentum=0.95, nesterov=True)
197 | # self.scheduler = t.optim.lr_scheduler.CyclicLR(self.opt, args.opt_base_lr, args.opt_max_lr, step_size_up=1, step_size_down=2, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
198 | # self.meta_scheduler = t.optim.lr_scheduler.CyclicLR(self.meta_opt, args.meta_opt_base_lr, args.meta_opt_max_lr, step_size_up=1, step_size_down=2, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.9, max_momentum=0.99, last_epoch=-1)
199 | # # exp_range
200 |
201 |
202 | if use_cuda:
203 | self.model = self.model.cuda()
204 |
205 | def innerProduct(self, u, i, j):
206 | pred_i = t.sum(t.mul(u,i), dim=1)*args.inner_product_mult
207 | pred_j = t.sum(t.mul(u,j), dim=1)*args.inner_product_mult
208 | return pred_i, pred_j
209 |
210 | def SSL(self, user_embeddings, item_embeddings, target_user_embeddings, target_item_embeddings, user_step_index):
211 | def row_shuffle(embedding):
212 | corrupted_embedding = embedding[t.randperm(embedding.size()[0])]
213 | return corrupted_embedding
214 | def row_column_shuffle(embedding):
215 | corrupted_embedding = embedding[t.randperm(embedding.size()[0])]
216 | corrupted_embedding = corrupted_embedding[:,t.randperm(corrupted_embedding.size()[1])]
217 | return corrupted_embedding
218 | def score(x1, x2):
219 | return t.sum(t.mul(x1, x2), 1)
220 |
221 | def neg_sample_pair(x1, x2, τ = 0.05):
222 | for i in range(x1.shape[0]):
223 | index_set = set(np.arange(x1.shape[0]))
224 | index_set.remove(i)
225 | index_set_neg = t.as_tensor(np.array(list(index_set))).long().cuda()
226 |
227 | x_pos = x1[i].repeat(x1.shape[0]-1, 1)
228 | x_neg = x2[index_set]
229 |
230 | if i==0:
231 | x_pos_all = x_pos
232 | x_neg_all = x_neg
233 | else:
234 | x_pos_all = t.cat((x_pos_all, x_pos), 0)
235 | x_neg_all = t.cat((x_neg_all, x_neg), 0)
236 | x_pos_all = t.as_tensor(x_pos_all) #[9900, 100]
237 | x_neg_all = t.as_tensor(x_neg_all) #[9900, 100]
238 |
239 | return x_pos_all, x_neg_all
240 |
241 | def one_neg_sample_pair_index(i, step_index, embedding1, embedding2):
242 |
243 | index_set = set(np.array(step_index))
244 | index_set.remove(i.item())
245 | neg2_index = t.as_tensor(np.array(list(index_set))).long().cuda()
246 |
247 | neg1_index = t.ones((2,), dtype=t.long)
248 | neg1_index = neg1_index.new_full((len(index_set),), i)
249 |
250 | neg_score_pre = t.sum(compute(embedding1, embedding2, neg1_index, neg2_index).squeeze())
251 | return neg_score_pre
252 |
253 | def multi_neg_sample_pair_index(batch_index, step_index, embedding1, embedding2): #small, big, target, beh: [100], [1024], [31882, 16], [31882, 16]
254 |
255 | index_set = set(np.array(step_index.cpu()))
256 | batch_index_set = set(np.array(batch_index.cpu()))
257 | neg2_index_set = index_set - batch_index_set #beh
258 | neg2_index = t.as_tensor(np.array(list(neg2_index_set))).long().cuda() #[910]
259 | neg2_index = t.unsqueeze(neg2_index, 0) #[1, 910]
260 | neg2_index = neg2_index.repeat(len(batch_index), 1) #[100, 910]
261 | neg2_index = t.reshape(neg2_index, (1, -1)) #[1, 91000]
262 | neg2_index = t.squeeze(neg2_index) #[91000]
263 | #target
264 | neg1_index = batch_index.long().cuda() #[100]
265 | neg1_index = t.unsqueeze(neg1_index, 1) #[100, 1]
266 | neg1_index = neg1_index.repeat(1, len(neg2_index_set)) #[100, 910]
267 | neg1_index = t.reshape(neg1_index, (1, -1)) #[1, 91000]
268 | neg1_index = t.squeeze(neg1_index) #[91000]
269 |
270 | neg_score_pre = t.sum(compute(embedding1, embedding2, neg1_index, neg2_index).squeeze().view(len(batch_index), -1), -1) #[91000,1]==>[91000]==>[100, 910]==>[100]
271 | return neg_score_pre #[100]
272 |
273 | def compute(x1, x2, neg1_index=None, neg2_index=None, τ = 0.05): #[1024, 16], [1024, 16]
274 |
275 | if neg1_index is not None:
276 | x1 = x1[neg1_index]
277 | x2 = x2[neg2_index]
278 |
279 | N = x1.shape[0]
280 | D = x1.shape[1]
281 |
282 | x1 = x1
283 | x2 = x2
284 |
285 | scores = t.exp(t.div(t.bmm(x1.view(N, 1, D), x2.view(N, D, 1)).view(N, 1), np.power(D, 1)+1e-8)) #[1024, 1]
286 |
287 | return scores
288 | def single_infoNCE_loss_simple(embedding1, embedding2):
289 | pos = score(embedding1, embedding2) #[100]
290 | neg1 = score(embedding2, row_column_shuffle(embedding1))
291 | one = t.cuda.FloatTensor(neg1.shape[0]).fill_(1) #[100]
292 | # one = zeros = t.ones(neg1.shape[0])
293 | con_loss = t.sum(-t.log(1e-8 + t.sigmoid(pos))-t.log(1e-8 + (one - t.sigmoid(neg1))))
294 | return con_loss
295 |
296 | #use_less
297 | def single_infoNCE_loss(embedding1, embedding2):
298 | N = embedding1.shape[0]
299 | D = embedding1.shape[1]
300 |
301 | pos_score = compute(embedding1, embedding2).squeeze() #[100, 1]
302 |
303 | neg_x1, neg_x2 = neg_sample_pair(embedding1, embedding2) #[9900, 100], [9900, 100]
304 | neg_score = t.sum(compute(neg_x1, neg_x2).view(N, (N-1)), dim=1) #[100]
305 | con_loss = -t.log(1e-8 +t.div(pos_score, neg_score))
306 | con_loss = t.mean(con_loss)
307 | return max(0, con_loss)
308 |
309 | def single_infoNCE_loss_one_by_one(embedding1, embedding2, step_index): #target, beh
310 | N = step_index.shape[0]
311 | D = embedding1.shape[1]
312 |
313 | pos_score = compute(embedding1[step_index], embedding2[step_index]).squeeze() #[1024]
314 | neg_score = t.zeros((N,), dtype = t.float64).cuda() #[1024]
315 |
316 | #-------------------------------------------------multi version-----------------------------------------------------
317 | steps = int(np.ceil(N / args.SSL_batch)) #separate the batch to smaller one
318 | for i in range(steps):
319 | st = i * args.SSL_batch
320 | ed = min((i+1) * args.SSL_batch, N)
321 | batch_index = step_index[st: ed]
322 |
323 | neg_score_pre = multi_neg_sample_pair_index(batch_index, step_index, embedding1, embedding2)
324 | if i ==0:
325 | neg_score = neg_score_pre
326 | else:
327 | neg_score = t.cat((neg_score, neg_score_pre), 0)
328 | #-------------------------------------------------multi version-----------------------------------------------------
329 |
330 | con_loss = -t.log(1e-8 +t.div(pos_score, neg_score+1e-8)) #[1024]/[1024]==>1024
331 |
332 |
333 | assert not t.any(t.isnan(con_loss))
334 | assert not t.any(t.isinf(con_loss))
335 |
336 | return t.where(t.isnan(con_loss), t.full_like(con_loss, 0+1e-8), con_loss)
337 |
338 | user_con_loss_list = []
339 | item_con_loss_list = []
340 |
341 | SSL_len = int(user_step_index.shape[0]/10)
342 | user_step_index = t.as_tensor(np.random.choice(user_step_index.cpu(), size=SSL_len, replace=False, p=None)).cuda()
343 |
344 | for i in range(len(self.behaviors_SSL)):
345 |
346 | user_con_loss_list.append(single_infoNCE_loss_one_by_one(user_embeddings[-1], user_embeddings[i], user_step_index))
347 |
348 | user_con_losss = t.stack(user_con_loss_list, dim=0)
349 |
350 | return user_con_loss_list, user_step_index #4*[1024]
351 |
352 | def run(self):
353 |
354 | self.prepareModel()
355 | if args.isload == True:
356 | print("----------------------pre test:")
357 | HR, NDCG = self.testEpoch(self.test_loader)
358 | print(f"HR: {HR} , NDCG: {NDCG}")
359 | log('Model Prepared')
360 |
361 |
362 | cvWait = 0
363 | self.best_HR = 0
364 | self.best_NDCG = 0
365 | flag = 0
366 |
367 | self.user_embed = None
368 | self.item_embed = None
369 | self.user_embeds = None
370 | self.item_embeds = None
371 |
372 |
373 | print("Test before train:")
374 | HR, NDCG = self.testEpoch(self.test_loader)
375 |
376 | for e in range(self.curEpoch, args.epoch+1):
377 | self.curEpoch = e
378 |
379 | self.meta_flag = 0
380 | if e%args.meta_slot == 0:
381 | self.meta_flag=1
382 |
383 |
384 | log("*****************Start epoch: %d ************************"%e)
385 |
386 | if args.isJustTest == False:
387 | epoch_loss, user_embed, item_embed, user_embeds, item_embeds = self.trainEpoch()
388 | self.train_loss.append(epoch_loss)
389 | print(f"epoch {e}/{args.epoch}, epoch loss: {epoch_loss}")
390 |
391 | else:
392 | break
393 |
394 | HR, NDCG = self.testEpoch(self.test_loader)
395 | self.his_hr.append(HR)
396 | self.his_ndcg.append(NDCG)
397 |
398 | self.scheduler.step()
399 | self.meta_scheduler.step()
400 |
401 | if HR > self.best_HR:
402 | self.best_HR = HR
403 | self.best_epoch = self.curEpoch
404 | cvWait = 0
405 | print("--------------------------------------------------------------------------------------------------------------------------best_HR", self.best_HR)
406 | # print("--------------------------------------------------------------------------------------------------------------------------NDCG", self.best_NDCG)
407 | self.user_embed = user_embed
408 | self.item_embed = item_embed
409 | self.user_embeds = user_embeds
410 | self.item_embeds = item_embeds
411 |
412 | self.saveHistory()
413 | self.saveModel()
414 |
415 |
416 |
417 | if NDCG > self.best_NDCG:
418 | self.best_NDCG = NDCG
419 | self.best_epoch = self.curEpoch
420 | cvWait = 0
421 | # print("--------------------------------------------------------------------------------------------------------------------------HR", self.best_HR)
422 | print("--------------------------------------------------------------------------------------------------------------------------best_NDCG", self.best_NDCG)
423 | self.user_embed = user_embed
424 | self.item_embed = item_embed
425 | self.user_embeds = user_embeds
426 | self.item_embeds = item_embeds
427 |
428 | self.saveHistory()
429 | self.saveModel()
430 |
431 |
432 |
433 | if (HR