├── BGNN.py ├── DataHandler.py ├── MV_Net.py ├── Params.py ├── README.md ├── Utils ├── README.md └── TimeLogger.py ├── data └── README.md ├── graph_utils.py └── main.py /BGNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import init 6 | from torch.autograd import Variable 7 | from Params import args 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class myModel(nn.Module): 16 | def __init__(self, userNum, itemNum, behavior, behavior_mats): 17 | super(myModel, self).__init__() 18 | 19 | self.userNum = userNum 20 | self.itemNum = itemNum 21 | self.behavior = behavior 22 | self.behavior_mats = behavior_mats 23 | 24 | self.embedding_dict = self.init_embedding() 25 | self.weight_dict = self.init_weight() 26 | self.gcn = GCN(self.userNum, self.itemNum, self.behavior, self.behavior_mats) 27 | 28 | 29 | def init_embedding(self): 30 | 31 | embedding_dict = { 32 | 'user_embedding': None, 33 | 'item_embedding': None, 34 | 'user_embeddings': None, 35 | 'item_embeddings': None, 36 | } 37 | return embedding_dict 38 | 39 | def init_weight(self): 40 | initializer = nn.init.xavier_uniform_ 41 | 42 | weight_dict = nn.ParameterDict({ 43 | 'w_self_attention_item': nn.Parameter(initializer(torch.empty([args.hidden_dim, args.hidden_dim]))), 44 | 'w_self_attention_user': nn.Parameter(initializer(torch.empty([args.hidden_dim, args.hidden_dim]))), 45 | 'w_self_attention_cat': nn.Parameter(initializer(torch.empty([args.head_num*args.hidden_dim, args.hidden_dim]))), 46 | 'alpha': nn.Parameter(torch.ones(2)), 47 | }) 48 | return weight_dict 49 | 50 | 51 | def forward(self): 52 | 53 | user_embed, item_embed, user_embeds, item_embeds = self.gcn() 54 | 55 | 56 | return user_embed, item_embed, user_embeds, item_embeds 57 | 58 | def para_dict_to_tenser(self, para_dict): 59 | 60 | tensors = [] 61 | for beh in para_dict.keys(): 62 | tensors.append(para_dict[beh]) 63 | tensors = torch.stack(tensors, dim=0) 64 | 65 | return tensors.float() 66 | 67 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 68 | if source_params is not None: 69 | for tgt, src in zip(self.named_parameters(), source_params): 70 | name_t, param_t = tgt 71 | grad = src 72 | if first_order: 73 | grad = to_var(grad.detach().data) 74 | tmp = param_t - lr_inner * grad 75 | self.set_param(self, name_t, tmp) 76 | else: 77 | 78 | for name, param in self.named_parameters()(self): 79 | if not detach: 80 | grad = param.grad 81 | if first_order: 82 | grad = to_var(grad.detach().data) 83 | tmp = param - lr_inner * grad 84 | self.set_param(self, name, tmp) 85 | else: 86 | param = param.detach_() 87 | self.set_param(self, name, param) 88 | 89 | 90 | class GCN(nn.Module): 91 | def __init__(self, userNum, itemNum, behavior, behavior_mats): 92 | super(GCN, self).__init__() 93 | self.userNum = userNum 94 | self.itemNum = itemNum 95 | self.hidden_dim = args.hidden_dim 96 | 97 | self.behavior = behavior 98 | self.behavior_mats = behavior_mats 99 | 100 | self.user_embedding, self.item_embedding = self.init_embedding() 101 | 102 | self.alpha, self.i_concatenation_w, self.u_concatenation_w, self.i_input_w, self.u_input_w = self.init_weight() 103 | 104 | self.sigmoid = torch.nn.Sigmoid() 105 | self.act = torch.nn.PReLU() 106 | self.dropout = 
torch.nn.Dropout(args.drop_rate) 107 | 108 | self.gnn_layer = eval(args.gnn_layer) 109 | self.layers = nn.ModuleList() 110 | for i in range(0, len(self.gnn_layer)): 111 | self.layers.append(GCNLayer(args.hidden_dim, args.hidden_dim, self.userNum, self.itemNum, self.behavior, self.behavior_mats)) 112 | 113 | def init_embedding(self): 114 | user_embedding = torch.nn.Embedding(self.userNum, args.hidden_dim) 115 | item_embedding = torch.nn.Embedding(self.itemNum, args.hidden_dim) 116 | nn.init.xavier_uniform_(user_embedding.weight) 117 | nn.init.xavier_uniform_(item_embedding.weight) 118 | 119 | return user_embedding, item_embedding 120 | 121 | def init_weight(self): 122 | alpha = nn.Parameter(torch.ones(2)) 123 | i_concatenation_w = nn.Parameter(torch.Tensor(len(eval(args.gnn_layer))*args.hidden_dim, args.hidden_dim)) 124 | u_concatenation_w = nn.Parameter(torch.Tensor(len(eval(args.gnn_layer))*args.hidden_dim, args.hidden_dim)) 125 | i_input_w = nn.Parameter(torch.Tensor(args.hidden_dim, args.hidden_dim)) 126 | u_input_w = nn.Parameter(torch.Tensor(args.hidden_dim, args.hidden_dim)) 127 | init.xavier_uniform_(i_concatenation_w) 128 | init.xavier_uniform_(u_concatenation_w) 129 | init.xavier_uniform_(i_input_w) 130 | init.xavier_uniform_(u_input_w) 131 | # init.xavier_uniform_(alpha) 132 | 133 | return alpha, i_concatenation_w, u_concatenation_w, i_input_w, u_input_w 134 | 135 | def forward(self, user_embedding_input=None, item_embedding_input=None): 136 | all_user_embeddings = [] 137 | all_item_embeddings = [] 138 | all_user_embeddingss = [] 139 | all_item_embeddingss = [] 140 | 141 | user_embedding = self.user_embedding.weight 142 | item_embedding = self.item_embedding.weight 143 | 144 | for i, layer in enumerate(self.layers): 145 | 146 | user_embedding, item_embedding, user_embeddings, item_embeddings = layer(user_embedding, item_embedding) 147 | 148 | norm_user_embeddings = F.normalize(user_embedding, p=2, dim=1) 149 | norm_item_embeddings = F.normalize(item_embedding, p=2, dim=1) 150 | 151 | all_user_embeddings.append(user_embedding) 152 | all_item_embeddings.append(item_embedding) 153 | all_user_embeddingss.append(user_embeddings) 154 | all_item_embeddingss.append(item_embeddings) 155 | 156 | user_embedding = torch.cat(all_user_embeddings, dim=1) 157 | item_embedding = torch.cat(all_item_embeddings, dim=1) 158 | user_embeddings = torch.cat(all_user_embeddingss, dim=2) 159 | item_embeddings = torch.cat(all_item_embeddingss, dim=2) 160 | 161 | user_embedding = torch.matmul(user_embedding , self.u_concatenation_w) 162 | item_embedding = torch.matmul(item_embedding , self.i_concatenation_w) 163 | user_embeddings = torch.matmul(user_embeddings , self.u_concatenation_w) 164 | item_embeddings = torch.matmul(item_embeddings , self.i_concatenation_w) 165 | 166 | 167 | return user_embedding, item_embedding, user_embeddings, item_embeddings #[31882, 16], [31882, 16], [4, 31882, 16], [4, 31882, 16] 168 | 169 | 170 | class GCNLayer(nn.Module): 171 | def __init__(self, in_dim, out_dim, userNum, itemNum, behavior, behavior_mats): 172 | super(GCNLayer, self).__init__() 173 | 174 | self.behavior = behavior 175 | self.behavior_mats = behavior_mats 176 | 177 | self.userNum = userNum 178 | self.itemNum = itemNum 179 | 180 | self.act = torch.nn.Sigmoid() 181 | self.i_w = nn.Parameter(torch.Tensor(in_dim, out_dim)) 182 | self.u_w = nn.Parameter(torch.Tensor(in_dim, out_dim)) 183 | self.ii_w = nn.Parameter(torch.Tensor(in_dim, out_dim)) 184 | init.xavier_uniform_(self.i_w) 185 | init.xavier_uniform_(self.u_w) 
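# In forward() below, each behavior's user-item interaction matrix behavior_mats[i]['A']
# propagates item embeddings to users, and its transpose behavior_mats[i]['AT'] propagates
# user embeddings to items. The per-behavior messages are mean-pooled and projected through a
# sigmoid-activated linear map (u_w / i_w); the per-behavior projections are returned as well.
# Note: ii_w is declared above but is neither initialized nor used in this layer.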
186 | 187 | def forward(self, user_embedding, item_embedding): 188 | 189 | user_embedding_list = [None]*len(self.behavior) 190 | item_embedding_list = [None]*len(self.behavior) 191 | 192 | for i in range(len(self.behavior)): 193 | user_embedding_list[i] = torch.spmm(self.behavior_mats[i]['A'], item_embedding) 194 | item_embedding_list[i] = torch.spmm(self.behavior_mats[i]['AT'], user_embedding) 195 | 196 | 197 | 198 | user_embeddings = torch.stack(user_embedding_list, dim=0) 199 | item_embeddings = torch.stack(item_embedding_list, dim=0) 200 | 201 | 202 | user_embedding = self.act(torch.matmul(torch.mean(user_embeddings, dim=0), self.u_w)) 203 | item_embedding = self.act(torch.matmul(torch.mean(item_embeddings, dim=0), self.i_w)) 204 | 205 | user_embeddings = self.act(torch.matmul(user_embeddings, self.u_w)) 206 | item_embeddings = self.act(torch.matmul(item_embeddings, self.i_w)) 207 | 208 | 209 | return user_embedding, item_embedding, user_embeddings, item_embeddings 210 | 211 | #------------------------------------------------------------------------------------------------------------------------------------------------ 212 | 213 | -------------------------------------------------------------------------------- /DataHandler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import pickle 4 | import numpy as np 5 | import scipy.sparse as sp 6 | from math import ceil 7 | import datetime 8 | 9 | from Params import args 10 | import graph_utils 11 | 12 | 13 | 14 | class RecDataset(data.Dataset): 15 | def __init__(self, data, num_item, train_mat=None, num_ng=1, is_training=True): 16 | super(RecDataset, self).__init__() 17 | 18 | self.data = np.array(data) 19 | self.num_item = num_item 20 | self.train_mat = train_mat 21 | self.is_training = is_training 22 | 23 | def ng_sample(self): 24 | assert self.is_training, 'no need to sampling when testing' 25 | dok_trainMat = self.train_mat.todok() 26 | length = self.data.shape[0] 27 | self.neg_data = np.random.randint(low=0, high=self.num_item, size=length) 28 | 29 | for i in range(length): # 30 | uid = self.data[i][0] 31 | iid = self.neg_data[i] 32 | if (uid, iid) in dok_trainMat: 33 | while (uid, iid) in dok_trainMat: 34 | iid = np.random.randint(low=0, high=self.num_item) 35 | self.neg_data[i] = iid 36 | self.neg_data[i] = iid 37 | 38 | def __len__(self): 39 | return len(self.data) 40 | 41 | def __getitem__(self, idx): 42 | user = self.data[idx][0] 43 | item_i = self.data[idx][1] 44 | 45 | if self.is_training: 46 | neg_data = self.neg_data 47 | item_j = neg_data[idx] 48 | return user, item_i, item_j 49 | else: 50 | return user, item_i 51 | 52 | def getMatrix(self): 53 | pass 54 | 55 | def getAdj(self): 56 | pass 57 | 58 | def sampleLargeGraph(self): 59 | 60 | 61 | def makeMask(): 62 | pass 63 | 64 | def updateBdgt(): 65 | pass 66 | 67 | def sample(): 68 | pass 69 | 70 | def constructData(self): 71 | pass 72 | 73 | 74 | 75 | 76 | class RecDataset_beh(data.Dataset): 77 | def __init__(self, beh, data, num_item, behaviors_data=None, num_ng=1, is_training=True): 78 | super(RecDataset_beh, self).__init__() 79 | 80 | self.data = np.array(data) 81 | self.num_item = num_item 82 | self.is_training = is_training 83 | self.beh = beh 84 | self.behaviors_data = behaviors_data 85 | 86 | self.length = self.data.shape[0] 87 | self.neg_data = [None]*self.length 88 | self.pos_data = [None]*self.length 89 | 90 | def ng_sample(self): 91 | assert self.is_training, 'no need to 
sampling when testing' 92 | 93 | for i in range(self.length): 94 | self.neg_data[i] = [None]*len(self.beh) 95 | self.pos_data[i] = [None]*len(self.beh) 96 | 97 | for index in range(len(self.beh)): 98 | 99 | 100 | train_u, train_v = self.behaviors_data[index].nonzero() 101 | beh_dok = self.behaviors_data[index].todok() 102 | 103 | set_pos = np.array(list(set(train_v))) 104 | 105 | self.pos_data_index = np.random.choice(set_pos, size=self.length, replace=True, p=None) 106 | self.neg_data_index = np.random.randint(low=0, high=self.num_item, size=self.length) 107 | 108 | 109 | for i in range(self.length): # 110 | 111 | uid = self.data[i][0] 112 | iid_neg = self.neg_data[i][index] = self.neg_data_index[i] 113 | iid_pos = self.pos_data[i][index] = self.pos_data_index[i] 114 | 115 | if (uid, iid_neg) in beh_dok: 116 | while (uid, iid_neg) in beh_dok: 117 | iid_neg = np.random.randint(low=0, high=self.num_item) 118 | self.neg_data[i][index] = iid_neg 119 | self.neg_data[i][index] = iid_neg 120 | 121 | if index == (len(self.beh)-1): 122 | self.pos_data[i][index] = train_v[i] 123 | elif (uid, iid_pos) not in beh_dok: 124 | if len(self.behaviors_data[index][uid].data)==0: 125 | self.pos_data[i][index] = -1 126 | else: 127 | t_array = self.behaviors_data[index][uid].toarray() 128 | pos_index = np.where(t_array!=0)[1] 129 | iid_pos = np.random.choice(pos_index, size = 1, replace=True, p=None)[0] 130 | self.pos_data[i][index] = iid_pos 131 | 132 | def __len__(self): 133 | return len(self.data) 134 | 135 | def __getitem__(self, idx): 136 | user = self.data[idx][0] 137 | item_i = self.pos_data[idx] 138 | 139 | if self.is_training: 140 | item_j = self.neg_data[idx] 141 | return user, item_i, item_j 142 | else: 143 | return user, item_i 144 | 145 | -------------------------------------------------------------------------------- /MV_Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.init as init 7 | 8 | from torch.nn import init 9 | from Params import args 10 | 11 | import numpy as np 12 | 13 | 14 | 15 | def to_var(x, requires_grad=True): 16 | if torch.cuda.is_available(): 17 | x = x.cuda() 18 | return Variable(x, requires_grad=requires_grad) 19 | 20 | 21 | class MetaModule(nn.Module): 22 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 23 | def params(self): 24 | for name, param in self.named_params(self): 25 | yield param 26 | 27 | def named_leaves(self): 28 | return [] 29 | 30 | def named_submodules(self): 31 | return [] 32 | 33 | def named_params(self, curr_module=None, memo=None, prefix=''): 34 | if memo is None: 35 | memo = set() 36 | 37 | if hasattr(curr_module, 'named_leaves'): 38 | for name, p in curr_module.named_leaves(): 39 | if p is not None and p not in memo: 40 | memo.add(p) 41 | yield prefix + ('.' if prefix else '') + name, p 42 | else: 43 | for name, p in curr_module._parameters.items(): 44 | if p is not None and p not in memo: 45 | memo.add(p) 46 | yield prefix + ('.' if prefix else '') + name, p 47 | 48 | for mname, module in curr_module.named_children(): 49 | submodule_prefix = prefix + ('.' 
if prefix else '') + mname 50 | for name, p in self.named_params(module, memo, submodule_prefix): 51 | yield name, p 52 | 53 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 54 | if source_params is not None: 55 | for tgt, src in zip(self.named_params(self), source_params): 56 | name_t, param_t = tgt 57 | # name_s, param_s = src 58 | # grad = param_s.grad 59 | # name_s, param_s = src 60 | grad = src 61 | if first_order: 62 | grad = to_var(grad.detach().data) 63 | tmp = param_t - lr_inner * grad 64 | self.set_param(self, name_t, tmp) 65 | else: 66 | 67 | for name, param in self.named_params(self): 68 | if not detach: 69 | grad = param.grad 70 | if first_order: 71 | grad = to_var(grad.detach().data) 72 | tmp = param - lr_inner * grad 73 | self.set_param(self, name, tmp) 74 | else: 75 | param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686 76 | self.set_param(self, name, param) 77 | 78 | def set_param(self, curr_mod, name, param): 79 | if '.' in name: 80 | n = name.split('.') 81 | module_name = n[0] 82 | rest = '.'.join(n[1:]) 83 | for name, mod in curr_mod.named_children(): 84 | if module_name == name: 85 | self.set_param(mod, rest, param) 86 | break 87 | else: 88 | setattr(curr_mod, name, param) 89 | 90 | def detach_params(self): 91 | for name, param in self.named_params(self): 92 | self.set_param(self, name, param.detach()) 93 | 94 | def copy(self, other, same_var=False): 95 | for name, param in other.named_params(): 96 | if not same_var: 97 | param = to_var(param.data.clone(), requires_grad=True) 98 | self.set_param(name, param) 99 | 100 | 101 | class MetaLinear(MetaModule): 102 | def __init__(self, *args, **kwargs): 103 | super().__init__() 104 | ignore = nn.Linear(*args, **kwargs) 105 | 106 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 107 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 108 | 109 | def forward(self, x): 110 | return F.linear(x, self.weight, self.bias) 111 | 112 | def named_leaves(self): 113 | return [('weight', self.weight), ('bias', self.bias)] 114 | 115 | 116 | class MetaConv2d(MetaModule): 117 | def __init__(self, *args, **kwargs): 118 | super().__init__() 119 | ignore = nn.Conv2d(*args, **kwargs) 120 | 121 | self.in_channels = ignore.in_channels 122 | self.out_channels = ignore.out_channels 123 | self.stride = ignore.stride 124 | self.padding = ignore.padding 125 | self.dilation = ignore.dilation 126 | self.groups = ignore.groups 127 | self.kernel_size = ignore.kernel_size 128 | 129 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 130 | 131 | if ignore.bias is not None: 132 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 133 | else: 134 | self.register_buffer('bias', None) 135 | 136 | def forward(self, x): 137 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 138 | 139 | def named_leaves(self): 140 | return [('weight', self.weight), ('bias', self.bias)] 141 | 142 | 143 | class MetaConvTranspose2d(MetaModule): 144 | def __init__(self, *args, **kwargs): 145 | super().__init__() 146 | ignore = nn.ConvTranspose2d(*args, **kwargs) 147 | 148 | self.stride = ignore.stride 149 | self.padding = ignore.padding 150 | self.dilation = ignore.dilation 151 | self.groups = ignore.groups 152 | 153 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 154 | 155 | if ignore.bias is not None: 156 | 
self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 157 | else: 158 | self.register_buffer('bias', None) 159 | 160 | def forward(self, x, output_size=None): 161 | output_padding = self._output_padding(x, output_size) 162 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding, 163 | output_padding, self.groups, self.dilation) 164 | 165 | def named_leaves(self): 166 | return [('weight', self.weight), ('bias', self.bias)] 167 | 168 | 169 | class MetaBatchNorm2d(MetaModule): 170 | def __init__(self, *args, **kwargs): 171 | super().__init__() 172 | ignore = nn.BatchNorm2d(*args, **kwargs) 173 | 174 | self.num_features = ignore.num_features 175 | self.eps = ignore.eps 176 | self.momentum = ignore.momentum 177 | self.affine = ignore.affine 178 | self.track_running_stats = ignore.track_running_stats 179 | 180 | if self.affine: 181 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 182 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 183 | 184 | if self.track_running_stats: 185 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 186 | self.register_buffer('running_var', torch.ones(self.num_features)) 187 | else: 188 | self.register_parameter('running_mean', None) 189 | self.register_parameter('running_var', None) 190 | 191 | def forward(self, x): 192 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 193 | self.training or not self.track_running_stats, self.momentum, self.eps) 194 | 195 | def named_leaves(self): 196 | return [('weight', self.weight), ('bias', self.bias)] 197 | 198 | 199 | def _weights_init(m): 200 | classname = m.__class__.__name__ 201 | # print(classname) 202 | if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d): 203 | init.kaiming_normal(m.weight) 204 | 205 | class LambdaLayer(MetaModule): 206 | def __init__(self, lambd): 207 | super(LambdaLayer, self).__init__() 208 | self.lambd = lambd 209 | 210 | def forward(self, x): 211 | return self.lambd(x) 212 | 213 | 214 | class BasicBlock(MetaModule): 215 | expansion = 1 216 | 217 | def __init__(self, in_planes, planes, stride=1, option='A'): 218 | super(BasicBlock, self).__init__() 219 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 220 | self.bn1 = MetaBatchNorm2d(planes) 221 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 222 | self.bn2 = MetaBatchNorm2d(planes) 223 | 224 | self.shortcut = nn.Sequential() 225 | if stride != 1 or in_planes != planes: 226 | if option == 'A': 227 | """ 228 | For CIFAR10 ResNet paper uses option A. 
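Option A keeps the shortcut parameter-free: it subsamples the feature map with stride 2 and zero-pads the extra channels (the LambdaLayer below), while option B uses a 1x1 convolution followed by batch normalization.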
229 | """ 230 | self.shortcut = LambdaLayer(lambda x: 231 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 232 | elif option == 'B': 233 | self.shortcut = nn.Sequential( 234 | MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 235 | MetaBatchNorm2d(self.expansion * planes) 236 | ) 237 | 238 | def forward(self, x): 239 | out = F.relu(self.bn1(self.conv1(x))) 240 | out = self.bn2(self.conv2(out)) 241 | out += self.shortcut(x) 242 | out = F.relu(out) 243 | return out 244 | 245 | 246 | class ResNet32(MetaModule): 247 | def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]): 248 | super(ResNet32, self).__init__() 249 | self.in_planes = 16 250 | 251 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 252 | self.bn1 = MetaBatchNorm2d(16) 253 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 254 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 255 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 256 | self.linear = MetaLinear(64, num_classes) 257 | 258 | self.apply(_weights_init) 259 | 260 | def _make_layer(self, block, planes, num_blocks, stride): 261 | strides = [stride] + [1]*(num_blocks-1) 262 | layers = [] 263 | for stride in strides: 264 | layers.append(block(self.in_planes, planes, stride)) 265 | self.in_planes = planes * block.expansion 266 | 267 | return nn.Sequential(*layers) 268 | 269 | def forward(self, x): 270 | out = F.relu(self.bn1(self.conv1(x))) 271 | out = self.layer1(out) 272 | out = self.layer2(out) 273 | out = self.layer3(out) 274 | out = F.avg_pool2d(out, out.size()[3]) 275 | out = out.view(out.size(0), -1) 276 | out = self.linear(out) 277 | return out 278 | 279 | 280 | 281 | class VNet(MetaModule): 282 | def __init__(self, input, hidden1, output): 283 | super(VNet, self).__init__() 284 | self.linear1 = MetaLinear(input, hidden1) 285 | self.relu1 = nn.ReLU(inplace=True) 286 | self.linear2 = MetaLinear(hidden1, output) 287 | # self.linear3 = MetaLinear(hidden2, output) 288 | 289 | def forward(self, x): 290 | x = self.linear1(x) 291 | x = self.relu1(x) 292 | # x = self.linear2(x) 293 | # x = self.relu1(x) 294 | out = self.linear2(x) 295 | return F.sigmoid(out) 296 | 297 | 298 | class MetaWeightNet(nn.Module): 299 | def __init__(self, beh_num): 300 | super(MetaWeightNet, self).__init__() 301 | 302 | self.beh_num = beh_num 303 | 304 | self.sigmoid = torch.nn.Sigmoid() 305 | self.act = torch.nn.LeakyReLU(negative_slope=args.slope) 306 | self.prelu = torch.nn.PReLU() 307 | self.relu = torch.nn.ReLU() 308 | self.tanhshrink = torch.nn.Tanhshrink() 309 | self.dropout7 = torch.nn.Dropout(args.drop_rate) 310 | self.dropout5 = torch.nn.Dropout(args.drop_rate1) 311 | self.batch_norm = torch.nn.BatchNorm1d(1) 312 | 313 | initializer = nn.init.xavier_uniform_ 314 | 315 | 316 | self.SSL_layer1 = nn.Linear(args.hidden_dim*3, int((args.hidden_dim*3)/2)) 317 | self.SSL_layer2 = nn.Linear(int((args.hidden_dim*3)/2), 1) 318 | self.SSL_layer3 = nn.Linear(args.hidden_dim*2, 1) 319 | 320 | self.RS_layer1 = nn.Linear(args.hidden_dim*3, int((args.hidden_dim*3)/2)) 321 | self.RS_layer2 = nn.Linear(int((args.hidden_dim*3)/2), 1) 322 | self.RS_layer3 = nn.Linear(args.hidden_dim, 1) 323 | 324 | 325 | 326 | self.beh_embedding = nn.Parameter(initializer(torch.empty([beh_num, args.hidden_dim]))).cuda() 327 | 328 | 329 | def forward(self, infoNCELoss_list, behavior_loss_multi_list, user_step_index, user_index_list, user_embeds, 
user_embed): 330 | 331 | infoNCELoss_list_weights = [None]*self.beh_num 332 | behavior_loss_multi_list_weights = [None]*self.beh_num 333 | for i in range(self.beh_num): 334 | 335 | #retailrocket-------------------------------------------------------------------------------------------------------------------------------------------------------------- 336 | # SSL_input = args.inner_product_mult*torch.cat((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embeds[i][user_step_index]), 1) #[] [1024, 16] 337 | # SSL_input = args.inner_product_mult*torch.cat((SSL_input, user_embed[user_step_index]), 1) 338 | # SSL_input3 = args.inner_product_mult*((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim*2))*torch.cat((user_embeds[i][user_step_index],user_embed[user_step_index]), 1)) 339 | 340 | # infoNCELoss_list_weights[i] = self.dropout7(self.prelu(self.SSL_layer1(SSL_input))) 341 | # infoNCELoss_list_weights[i] = np.sqrt(SSL_input.shape[1])*self.dropout7(self.SSL_layer2(infoNCELoss_list_weights[i]).squeeze()) 342 | 343 | # infoNCELoss_list_weights[i] = self.batch_norm(infoNCELoss_list_weights[i].unsqueeze(1)).squeeze() 344 | # infoNCELoss_list_weights[i] = args.inner_product_mult*self.sigmoid(infoNCELoss_list_weights[i]) 345 | # SSL_weight3 = self.dropout7(self.prelu(self.SSL_layer3(SSL_input3))) 346 | # SSL_weight3 = self.batch_norm(SSL_weight3).squeeze() 347 | 348 | # SSL_weight3 = args.inner_product_mult*self.sigmoid(SSL_weight3) 349 | # infoNCELoss_list_weights[i] = (infoNCELoss_list_weights[i] + SSL_weight3)/2 350 | 351 | # RS_input = args.inner_product_mult*torch.cat((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embed[user_index_list[i]]), 1) 352 | # RS_input = args.inner_product_mult*torch.cat((RS_input, user_embeds[i][user_index_list[i]]), 1) 353 | # RS_input3 = args.inner_product_mult*((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim))*user_embed[user_index_list[i]]) 354 | 355 | # behavior_loss_multi_list_weights[i] = self.dropout7(self.prelu(self.RS_layer1(RS_input))) 356 | # behavior_loss_multi_list_weights[i] = np.sqrt(RS_input.shape[1])*self.dropout7(self.RS_layer2(behavior_loss_multi_list_weights[i]).squeeze()) 357 | # behavior_loss_multi_list_weights[i] = self.batch_norm(behavior_loss_multi_list_weights[i].unsqueeze(1)) 358 | # behavior_loss_multi_list_weights[i] = args.inner_product_mult*self.sigmoid(behavior_loss_multi_list_weights[i]).squeeze() 359 | # RS_weight3 = self.dropout7(self.prelu(self.RS_layer3(RS_input3))) 360 | # RS_weight3 = self.batch_norm(RS_weight3).squeeze() 361 | # RS_weight3 = args.inner_product_mult*self.sigmoid(RS_weight3).squeeze() 362 | # behavior_loss_multi_list_weights[i] = behavior_loss_multi_list_weights[i] + RS_weight3 363 | #retailrocket-------------------------------------------------------------------------------------------------------------------------------------------------------------- 364 | 365 | #IJCAI,Tmall-------------------------------------------------------------------------------------------------------------------------------------------------------------- 366 | SSL_input = args.inner_product_mult*torch.cat((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embeds[i][user_step_index]), 1) #[] [1024, 16] 367 | SSL_input = args.inner_product_mult*torch.cat((SSL_input, user_embed[user_step_index]), 1) 368 | SSL_input3 = 
args.inner_product_mult*((infoNCELoss_list[i].unsqueeze(1).repeat(1, args.hidden_dim*2))*torch.cat((user_embeds[i][user_step_index],user_embed[user_step_index]), 1)) 369 | 370 | infoNCELoss_list_weights[i] = self.dropout7(self.prelu(self.SSL_layer1(SSL_input))) 371 | infoNCELoss_list_weights[i] = np.sqrt(SSL_input.shape[1])*self.dropout7(self.SSL_layer2(infoNCELoss_list_weights[i]).squeeze()) 372 | 373 | infoNCELoss_list_weights[i] = self.batch_norm(infoNCELoss_list_weights[i].unsqueeze(1)).squeeze() 374 | infoNCELoss_list_weights[i] = args.inner_product_mult*self.sigmoid(infoNCELoss_list_weights[i]) 375 | SSL_weight3 = self.dropout7(self.prelu(self.SSL_layer3(SSL_input3))) 376 | SSL_weight3 = self.batch_norm(SSL_weight3).squeeze() 377 | 378 | SSL_weight3 = args.inner_product_mult*self.sigmoid(SSL_weight3) 379 | infoNCELoss_list_weights[i] = (infoNCELoss_list_weights[i] + SSL_weight3)/2 380 | 381 | RS_input = args.inner_product_mult*torch.cat((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim)*args.inner_product_mult, user_embed[user_index_list[i]]), 1) 382 | RS_input = args.inner_product_mult*torch.cat((RS_input, user_embeds[i][user_index_list[i]]), 1) 383 | RS_input3 = args.inner_product_mult*((behavior_loss_multi_list[i].unsqueeze(1).repeat(1, args.hidden_dim))*user_embed[user_index_list[i]]) 384 | behavior_loss_multi_list_weights[i] = self.dropout7(self.prelu(self.RS_layer1(RS_input))) 385 | behavior_loss_multi_list_weights[i] = np.sqrt(RS_input.shape[1])*self.dropout7(self.RS_layer2(behavior_loss_multi_list_weights[i]).squeeze()) 386 | behavior_loss_multi_list_weights[i] = self.batch_norm(behavior_loss_multi_list_weights[i].unsqueeze(1)) 387 | behavior_loss_multi_list_weights[i] = args.inner_product_mult*self.sigmoid(behavior_loss_multi_list_weights[i]).squeeze() 388 | RS_weight3 = self.dropout7(self.prelu(self.RS_layer3(RS_input3))) 389 | RS_weight3 = self.batch_norm(RS_weight3).squeeze() 390 | RS_weight3 = args.inner_product_mult*self.sigmoid(RS_weight3).squeeze() 391 | behavior_loss_multi_list_weights[i] = behavior_loss_multi_list_weights[i] + RS_weight3 392 | 393 | 394 | return infoNCELoss_list_weights, behavior_loss_multi_list_weights 395 | 396 | 397 | 398 | -------------------------------------------------------------------------------- /Params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser(description='Model Params') 6 | 7 | 8 | # #for this model 9 | # parser.add_argument('--hidden_dim', default=16, type=int, help='embedding size') 10 | # parser.add_argument('--gnn_layer', default="[16,16,16]", type=str, help='gnn layers: number + dim') 11 | # parser.add_argument('--dataset', default='IJCAI_15', type=str, help='name of dataset') 12 | # parser.add_argument('--point', default='for_meta_hidden_dim', type=str, help='') 13 | # parser.add_argument('--title', default='dim__8', type=str, help='title of model') 14 | # parser.add_argument('--sampNum', default=10, type=int, help='batch size for sampling') 15 | 16 | # #for train 17 | # parser.add_argument('--lr', default=3e-4, type=float, help='learning rate') 18 | # parser.add_argument('--opt_base_lr', default=1e-3, type=float, help='learning rate') 19 | # parser.add_argument('--opt_max_lr', default=2e-3, type=float, help='learning rate') 20 | # parser.add_argument('--opt_weight_decay', default=1e-4, type=float, help='weight decay regularizer') 21 | # parser.add_argument('--meta_opt_base_lr', 
default=1e-4, type=float, help='learning rate') 22 | # parser.add_argument('--meta_opt_max_lr', default=1e-3, type=float, help='learning rate') 23 | # parser.add_argument('--meta_opt_weight_decay', default=1e-4, type=float, help='weight decay regularizer') 24 | # parser.add_argument('--meta_lr', default=1e-3, type=float, help='_meta_learning rate') 25 | 26 | # parser.add_argument('--batch', default=8192, type=int, help='batch size') 27 | # parser.add_argument('--meta_batch', default=128, type=int, help='batch size') 28 | # parser.add_argument('--SSL_batch', default=30, type=int, help='batch size') 29 | # parser.add_argument('--reg', default=1e-3, type=float, help='weight decay regularizer') 30 | # parser.add_argument('--beta', default=0.005, type=float, help='scale of infoNCELoss') 31 | # parser.add_argument('--epoch', default=300, type=int, help='number of epochs') 32 | # # parser.add_argument('--decay', default=0.96, type=float, help='weight decay rate') 33 | # parser.add_argument('--shoot', default=10, type=int, help='K of top k') 34 | # parser.add_argument('--inner_product_mult', default=1, type=float, help='multiplier for the result') 35 | # parser.add_argument('--drop_rate', default=0.8, type=float, help='drop_rate') 36 | # parser.add_argument('--drop_rate1', default=0.5, type=float, help='drop_rate') 37 | # parser.add_argument('--seed', type=int, default=6) 38 | # parser.add_argument('--slope', type=float, default=0.1) 39 | # parser.add_argument('--patience', type=int, default=100) 40 | # #for save and read 41 | # parser.add_argument('--path', default='/home/ww/Code/MultiBehavior_BASELINE/MB-GCN/Datasets/', type=str, help='data path') 42 | # parser.add_argument('--save_path', default='tem', help='file name to save model and training record') 43 | # parser.add_argument('--load_model', default=None, help='model name to load') 44 | # parser.add_argument('--target', default='buy', type=str, help='target behavior to predict on') 45 | # parser.add_argument('--isload', default=False , type=bool, help='whether load model') 46 | # parser.add_argument('--isJustTest', default=False , type=bool, help='whether load model') 47 | # parser.add_argument('--loadModelPath', default='/home/ww/Code/work3/BSTRec/Model/IJCAI_15/for_meta_hidden_dim_dim__8_IJCAI_15_2021_07_10__14_11_55_lr_0.0003_reg_0.001_batch_size_4096_gnn_layer_[16,16,16].pth', type=str, help='loadModelPath') 48 | # # #Tmall: # loadPath_SSL_meta = "/home/ww/Code/work3/BSTRec/Model/Tmall/for_meta_hidden_dim_dim__8_Tmall_2021_07_08__01_35_54_lr_0.0003_reg_0.001_batch_size_4096_gnn_layer_[16,16,16].pth" 49 | # # #IJCAI_15: # loadPath_SSL_meta = "/home/ww/Code/work3/BSTRec/Model/IJCAI_15/for_meta_hidden_dim_dim__8_IJCAI_15_2021_07_10__14_11_55_lr_0.0003_reg_0.001_batch_size_4096_gnn_layer_[16,16,16].pth" 50 | # # #retailrocket: # loadPath_SSL_meta = "/home/ww/Code/work3/BSTRec/Model/retailrocket/for_meta_hidden_dim_dim__8_retailrocket_2021_07_10__18_35_32_lr_0.0003_reg_0.01_batch_size_1024_gnn_layer_[16,16,16].pth" 51 | 52 | 53 | 54 | # #use less 55 | # # parser.add_argument('--memosize', default=2, type=int, help='memory size') 56 | # parser.add_argument('--head_num', default=4, type=int, help='head_num_of_multihead_attention') 57 | # parser.add_argument('--beta_multi_behavior', default=0.005, type=float, help='scale of infoNCELoss') 58 | # parser.add_argument('--sampNum_slot', default=30, type=int, help='SSL_step') 59 | # parser.add_argument('--SSL_slot', default=1, type=int, help='SSL_step') 60 | # parser.add_argument('--k', default=2, 
type=float, help='MFB') 61 | # parser.add_argument('--meta_time_rate', default=0.8, type=float, help='gating rate') 62 | # parser.add_argument('--meta_behavior_rate', default=0.8, type=float, help='gating rate') 63 | # parser.add_argument('--meta_slot', default=2, type=int, help='epoch number for each SSL') 64 | # parser.add_argument('--time_slot', default=60*60*24*360, type=float, help='length of time slots') 65 | # parser.add_argument('--hidden_dim_meta', default=16, type=int, help='embedding size') 66 | # # parser.add_argument('--att_head', default=2, type=int, help='number of attention heads') 67 | # # parser.add_argument('--gnn_layer', default=2, type=int, help='number of gnn layers') 68 | # # parser.add_argument('--trnNum', default=10000, type=int, help='number of training instances per epoch') 69 | # # parser.add_argument('--deep_layer', default=0, type=int, help='number of deep layers to make the final prediction') 70 | # # parser.add_argument('--iiweight', default=0.3, type=float, help='weight for ii') 71 | # # parser.add_argument('--graphSampleN', default=10000, type=int, help='use 25000 for training and 200000 for testing, empirically') 72 | # # parser.add_argument('--divSize', default=1000, type=int, help='div size for smallTestEpoch') 73 | # # parser.add_argument('--tstEpoch', default=1, type=int, help='number of epoch to test while training') 74 | # # parser.add_argument('--subUsrSize', default=10, type=int, help='number of item for each sub-user') 75 | # # parser.add_argument('--subUsrDcy', default=0.9, type=float, help='decay factor for sub-users over time') 76 | # # parser.add_argument('--slot', default=0.5, type=float, help='length of time slots') 77 | 78 | 79 | return parser.parse_args() 80 | args = parse_args() 81 | 82 | # 83 | # args.user = 805506#147894 84 | # args.item = 584050#99037 85 | # ML10M 86 | # args.user = 67788 87 | # args.item = 8704 88 | # yelp 89 | # args.user = 19800 90 | # args.item = 22734 91 | 92 | # swap user and item 93 | # tem = args.user 94 | # args.user = args.item 95 | # args.item = tem 96 | 97 | # args.decay_step = args.trn_num 98 | # args.decay_step = args.item 99 | # args.decay_step = args.trnNum 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CML 2 | 3 | 4 | 5 | This repository contains the PyTorch code and datasets for the paper: 6 | 7 | > Wei, Wei and Huang, Chao and Xia, Lianghao and Xu, Yong and Zhao, Jiashu and Yin, Dawei. Contrastive Meta Learning with Behavior Multiplicity for Recommendation. Paper on arXiv. 8 | 9 | 10 | ## Introduction 11 | Contrastive Meta Learning (CML) leverages a multi-behavior learning paradigm to model diverse and multiplex user-item relationships, and to tackle the label scarcity problem for target behaviors. The designed multi-behavior contrastive learning task captures transferable user-item relationships from heterogeneous, multi-typed user behavior data. The proposed meta contrastive encoding scheme allows CML to preserve personalized multi-behavior characteristics, so that it remains reflective of diverse behavior-aware user preferences under a customized self-supervised framework.
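To make the contrastive objective concrete, the sketch below is a simplified, illustrative version of the multi-behavior contrastive loss (the actual loss in `main.py` additionally applies meta-learned sample weights and its own negative sampling; all function and tensor names here are placeholders). It treats a user's embedding under the target behavior and under an auxiliary behavior as a positive pair, and contrasts it against the other users in the sampled batch:

```
import torch
import torch.nn.functional as F

def behavior_infonce(target_emb, aux_emb, user_idx, temperature=0.05):
    # target_emb / aux_emb: [num_users, dim] user embeddings from two behavior-specific views
    # user_idx: 1-D LongTensor of sampled user ids forming the batch
    z1 = F.normalize(target_emb[user_idx], dim=1)         # target-behavior view, [B, dim]
    z2 = F.normalize(aux_emb[user_idx], dim=1)             # auxiliary-behavior view, [B, dim]
    logits = z1 @ z2.t() / temperature                     # pairwise similarities, [B, B]
    labels = torch.arange(z1.size(0), device=z1.device)    # positives sit on the diagonal
    return F.cross_entropy(logits, labels)                 # InfoNCE written as cross-entropy
```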
12 | 13 | 14 | ## Citation 15 | ``` 16 | @inproceedings{wei2022contrastive, 17 | title={Contrastive meta learning with behavior multiplicity for recommendation}, 18 | author={Wei, Wei and Huang, Chao and Xia, Lianghao and Xu, Yong and Zhao, Jiashu and Yin, Dawei}, 19 | booktitle={Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining}, 20 | pages={1120--1128}, 21 | year={2022} 22 | } 23 | ``` 24 | 25 | 26 | ## Environment 27 | 28 | The CML code is implemented and tested under the following development environment: 29 | 30 | - Python 3.6 31 | - torch==1.8.1+cu111 32 | - scipy==1.6.2 33 | - tqdm==4.61.2 34 | 35 | 36 | 37 | ## Datasets 38 | 39 | #### Raw data: 40 | - IJCAI contest: https://tianchi.aliyun.com/dataset/dataDetail?dataId=47 41 | - Retail Rocket: https://www.kaggle.com/retailrocket/ecommerce-dataset 42 | - Tmall: https://tianchi.aliyun.com/dataset/dataDetail?dataId=649 43 | #### Processed data: 44 | - The processed IJCAI dataset is under the /datasets folder. 45 | 46 | 47 | ## Usage 48 | 49 | The commands to train CML on the Tmall/IJCAI/Retailrocket datasets are as follows. The commands specify the hyperparameter settings that generate the reported results in the paper. 50 | 51 | * Tmall 52 | ``` 53 | python main.py --path=./datasets/ --dataset=Tmall --opt_base_lr=1e-3 --opt_max_lr=5e-3 --opt_weight_decay=1e-4 --meta_opt_base_lr=1e-4 --meta_opt_max_lr=2e-3 --meta_opt_weight_decay=1e-4 --meta_lr=1e-3 --batch=8192 --meta_batch=128 --SSL_batch=18 54 | ``` 55 | * IJCAI 56 | ``` 57 | python main.py --path=./datasets/ --dataset=IJCAI_15 --sampNum=10 --opt_base_lr=1e-3 --opt_max_lr=2e-3 --opt_weight_decay=1e-4 --meta_opt_base_lr=1e-4 --meta_opt_max_lr=1e-3 --meta_opt_weight_decay=1e-4 --meta_lr=1e-3 --batch=8192 --meta_batch=128 --SSL_batch=30 58 | ``` 59 | * Retailrocket 60 | ``` 61 | python main.py --path=./datasets/ --dataset='retailrocket' --sampNum=40 --lr=3e-4 --opt_base_lr=1e-4 --opt_max_lr=1e-3 --opt_weight_decay=1e-4 --meta_opt_base_lr=1e-4 --meta_opt_max_lr=1e-3 --meta_opt_weight_decay=1e-3 --meta_lr=1e-3 --batch=2048 --meta_batch=128 --SSL_batch=15 62 | ``` 63 | 64 | 65 | 66 | 72 | 73 | 74 | 75 | 76 | 77 | 79 | 80 | 81 | 82 | 83 | 84 | An optimized version of the code will be released in a few days. 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /Utils/README.md: -------------------------------------------------------------------------------- 1 | .
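`TimeLogger.py` in this folder provides timestamped logging (`log`) and simple timing helpers (`marktime`, `SpentTime`, `SpentTooLong`); `main.py` imports `log` from here. A minimal usage sketch:

```
from Utils.TimeLogger import log, marktime, SpentTime

marktime('epoch')                     # remember the current time under a marker
log('training started', save=True)    # print a timestamped message and keep it in the log buffer
log('epoch took %s' % SpentTime('epoch'))
```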
2 | -------------------------------------------------------------------------------- /Utils/TimeLogger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | logmsg = '' 4 | timemark = dict() 5 | saveDefault = False 6 | def log(msg, save=None, oneline=False): 7 | global logmsg 8 | global saveDefault 9 | time = datetime.datetime.now() 10 | tem = '%s: %s' % (time, msg) 11 | if save != None: 12 | if save: 13 | logmsg += tem + '\n' 14 | elif saveDefault: 15 | logmsg += tem + '\n' 16 | if oneline: 17 | print(tem, end='\r') 18 | else: 19 | print(tem) 20 | 21 | def marktime(marker): 22 | global timemark 23 | timemark[marker] = datetime.datetime.now() 24 | 25 | def SpentTime(marker): 26 | global timemark 27 | if marker not in timemark: 28 | msg = 'LOGGER ERROR, marker', marker, ' not found' 29 | tem = '%s: %s' % (time, msg) 30 | print(tem) 31 | return False 32 | return datetime.datetime.now() - timemark[marker] 33 | 34 | def SpentTooLong(marker, day=0, hour=0, minute=0, second=0): 35 | global timemark 36 | if marker not in timemark: 37 | msg = 'LOGGER ERROR, marker', marker, ' not found' 38 | tem = '%s: %s' % (time, msg) 39 | print(tem) 40 | return False 41 | return datetime.datetime.now() - timemark[marker] >= datetime.timedelta(days=day, hours=hour, minutes=minute, seconds=second) 42 | 43 | if __name__ == '__main__': 44 | log('') -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | . 2 | -------------------------------------------------------------------------------- /graph_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | from scipy.sparse import * 4 | import torch 5 | 6 | from Params import args 7 | 8 | 9 | def get_use(behaviors_data): 10 | 11 | behavior_mats = {} 12 | 13 | behaviors_data = (behaviors_data != 0) * 1 14 | 15 | behavior_mats['A'] = matrix_to_tensor(normalize_adj(behaviors_data)) 16 | behavior_mats['AT'] = matrix_to_tensor(normalize_adj(behaviors_data.T)) 17 | behavior_mats['A_ori'] = None 18 | 19 | return behavior_mats 20 | 21 | 22 | def normalize_adj(adj): 23 | """Symmetrically normalize adjacency matrix.""" 24 | adj = sp.coo_matrix(adj) 25 | rowsum = np.array(adj.sum(1)) 26 | rowsum_diag = sp.diags(np.power(rowsum+1e-8, -0.5).flatten()) 27 | 28 | colsum = np.array(adj.sum(0)) 29 | colsum_diag = sp.diags(np.power(colsum+1e-8, -0.5).flatten()) 30 | 31 | 32 | return adj 33 | 34 | 35 | def matrix_to_tensor(cur_matrix): 36 | if type(cur_matrix) != sp.coo_matrix: 37 | cur_matrix = cur_matrix.tocoo() 38 | indices = torch.from_numpy(np.vstack((cur_matrix.row, cur_matrix.col)).astype(np.int64)) 39 | values = torch.from_numpy(cur_matrix.data) 40 | shape = torch.Size(cur_matrix.shape) 41 | 42 | return torch.sparse.FloatTensor(indices, values, shape).to(torch.float32).cuda() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import random 3 | import pickle 4 | from scipy.sparse import csr_matrix 5 | import math 6 | import gc 7 | import time 8 | import random 9 | import datetime 10 | 11 | import torch as t 12 | import torch.nn as nn 13 | import torch.utils.data as dataloader 14 | import torch.nn.functional as F 15 | from torch.nn import init 16 | 17 | 
import graph_utils 18 | import DataHandler 19 | 20 | 21 | import BGNN 22 | import MV_Net 23 | 24 | from Params import args 25 | from Utils.TimeLogger import log 26 | from tqdm import tqdm 27 | 28 | t.backends.cudnn.benchmark=True 29 | 30 | if t.cuda.is_available(): 31 | use_cuda = True 32 | else: 33 | use_cuda = False 34 | 35 | MAX_FLAG = 0x7FFFFFFF 36 | 37 | now_time = datetime.datetime.now() 38 | modelTime = datetime.datetime.strftime(now_time,'%Y_%m_%d__%H_%M_%S') 39 | 40 | t.autograd.set_detect_anomaly(True) 41 | 42 | class Model(): 43 | def __init__(self): 44 | 45 | 46 | self.trn_file = args.path + args.dataset + '/trn_' 47 | self.tst_file = args.path + args.dataset + '/tst_int' 48 | # self.tst_file = args.path + args.dataset + '/BST_tst_int_59' 49 | #Tmall: 3,4,5,6,8,59 50 | #IJCAI_15: 5,6,8,10,13,53 51 | 52 | # self.meta_multi_file = args.path + args.dataset + '/meta_multi_beh_user_index' 53 | # self.meta_single_file = args.path + args.dataset + '/meta_single_beh_user_index' 54 | self.meta_multi_single_file = args.path + args.dataset + '/meta_multi_single_beh_user_index_shuffle' 55 | # /meta_multi_single_beh_user_index_shuffle 56 | # /new_multi_single 57 | 58 | # self.meta_multi = pickle.load(open(self.meta_multi_file, 'rb')) 59 | # self.meta_single = pickle.load(open(self.meta_single_file, 'rb')) 60 | self.meta_multi_single = pickle.load(open(self.meta_multi_single_file, 'rb')) 61 | 62 | self.t_max = -1 63 | self.t_min = 0x7FFFFFFF 64 | self.time_number = -1 65 | 66 | self.user_num = -1 67 | self.item_num = -1 68 | self.behavior_mats = {} 69 | self.behaviors = [] 70 | self.behaviors_data = {} 71 | 72 | #history 73 | self.train_loss = [] 74 | self.his_hr = [] 75 | self.his_ndcg = [] 76 | gc.collect() # 77 | 78 | self.relu = t.nn.ReLU() 79 | self.sigmoid = t.nn.Sigmoid() 80 | self.curEpoch = 0 81 | 82 | 83 | if args.dataset == 'Tmall': 84 | self.behaviors_SSL = ['pv','fav', 'cart', 'buy'] 85 | self.behaviors = ['pv','fav', 'cart', 'buy'] 86 | # self.behaviors = ['buy'] 87 | elif args.dataset == 'IJCAI_15': 88 | self.behaviors = ['click','fav', 'cart', 'buy'] 89 | # self.behaviors = ['buy'] 90 | self.behaviors_SSL = ['click','fav', 'cart', 'buy'] 91 | 92 | elif args.dataset == 'JD': 93 | self.behaviors = ['review','browse', 'buy'] 94 | self.behaviors_SSL = ['review','browse', 'buy'] 95 | 96 | elif args.dataset == 'retailrocket': 97 | self.behaviors = ['view','cart', 'buy'] 98 | # self.behaviors = ['buy'] 99 | self.behaviors_SSL = ['view','cart', 'buy'] 100 | 101 | 102 | for i in range(0, len(self.behaviors)): 103 | with open(self.trn_file + self.behaviors[i], 'rb') as fs: 104 | data = pickle.load(fs) 105 | self.behaviors_data[i] = data 106 | 107 | if data.get_shape()[0] > self.user_num: 108 | self.user_num = data.get_shape()[0] 109 | if data.get_shape()[1] > self.item_num: 110 | self.item_num = data.get_shape()[1] 111 | 112 | 113 | if data.data.max() > self.t_max: 114 | self.t_max = data.data.max() 115 | if data.data.min() < self.t_min: 116 | self.t_min = data.data.min() 117 | 118 | 119 | if self.behaviors[i]==args.target: 120 | self.trainMat = data 121 | self.trainLabel = 1*(self.trainMat != 0) 122 | self.labelP = np.squeeze(np.array(np.sum(self.trainLabel, axis=0))) 123 | 124 | 125 | time = datetime.datetime.now() 126 | print("Start building: ", time) 127 | for i in range(0, len(self.behaviors)): 128 | self.behavior_mats[i] = graph_utils.get_use(self.behaviors_data[i]) 129 | time = datetime.datetime.now() 130 | print("End building:", time) 131 | 132 | 133 | print("user_num: ", 
self.user_num) 134 | print("item_num: ", self.item_num) 135 | print("\n") 136 | 137 | 138 | #---------------------------------------------------------------------------------------------->>>>> 139 | #train_data 140 | train_u, train_v = self.trainMat.nonzero() 141 | train_data = np.hstack((train_u.reshape(-1,1), train_v.reshape(-1,1))).tolist() 142 | train_dataset = DataHandler.RecDataset_beh(self.behaviors, train_data, self.item_num, self.behaviors_data, True) 143 | self.train_loader = dataloader.DataLoader(train_dataset, batch_size=args.batch, shuffle=True, num_workers=4, pin_memory=True) 144 | 145 | #valid_data 146 | 147 | 148 | # test_data 149 | with open(self.tst_file, 'rb') as fs: 150 | data = pickle.load(fs) 151 | 152 | test_user = np.array([idx for idx, i in enumerate(data) if i is not None]) 153 | test_item = np.array([i for idx, i in enumerate(data) if i is not None]) 154 | # tstUsrs = np.reshape(np.argwhere(data!=None), [-1]) 155 | test_data = np.hstack((test_user.reshape(-1,1), test_item.reshape(-1,1))).tolist() 156 | # testbatch = np.maximum(1, args.batch * args.sampNum 157 | test_dataset = DataHandler.RecDataset(test_data, self.item_num, self.trainMat, 0, False) 158 | self.test_loader = dataloader.DataLoader(test_dataset, batch_size=args.batch, shuffle=False, num_workers=4, pin_memory=True) 159 | # -------------------------------------------------------------------------------------------------->>>>> 160 | 161 | def prepareModel(self): 162 | self.modelName = self.getModelName() 163 | # self.setRandomSeed() 164 | self.gnn_layer = eval(args.gnn_layer) 165 | self.hidden_dim = args.hidden_dim 166 | 167 | 168 | if args.isload == True: 169 | self.loadModel(args.loadModelPath) 170 | else: 171 | self.model = BGNN.myModel(self.user_num, self.item_num, self.behaviors, self.behavior_mats).cuda() 172 | self.meta_weight_net = MV_Net.MetaWeightNet(len(self.behaviors)).cuda() 173 | 174 | 175 | 176 | # #IJCAI_15 177 | # self.opt = t.optim.AdamW(self.model.parameters(), lr = args.lr, weight_decay = args.opt_weight_decay) 178 | # self.meta_opt = t.optim.AdamW(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay) 179 | # # self.meta_opt = t.optim.RMSprop(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay, momentum=0.95, centered=True) 180 | # self.scheduler = t.optim.lr_scheduler.CyclicLR(self.opt, args.opt_base_lr, args.opt_max_lr, step_size_up=5, step_size_down=10, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1) 181 | # self.meta_scheduler = t.optim.lr_scheduler.CyclicLR(self.meta_opt, args.meta_opt_base_lr, args.meta_opt_max_lr, step_size_up=2, step_size_down=3, mode='triangular', gamma=0.98, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.9, max_momentum=0.99, last_epoch=-1) 182 | # # 183 | 184 | 185 | #Tmall 186 | self.opt = t.optim.AdamW(self.model.parameters(), lr = args.lr, weight_decay = args.opt_weight_decay) 187 | self.meta_opt = t.optim.AdamW(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay) 188 | # self.meta_opt = t.optim.RMSprop(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay, momentum=0.95, centered=True) 189 | self.scheduler = t.optim.lr_scheduler.CyclicLR(self.opt, args.opt_base_lr, args.opt_max_lr, step_size_up=5, step_size_down=10, mode='triangular', gamma=0.99, scale_fn=None, 
scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1) 190 | self.meta_scheduler = t.optim.lr_scheduler.CyclicLR(self.meta_opt, args.meta_opt_base_lr, args.meta_opt_max_lr, step_size_up=3, step_size_down=7, mode='triangular', gamma=0.98, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.9, max_momentum=0.99, last_epoch=-1) 191 | # 0.993 192 | 193 | # # retailrocket 194 | # self.opt = t.optim.AdamW(self.model.parameters(), lr = args.lr, weight_decay = args.opt_weight_decay) 195 | # # self.meta_opt = t.optim.AdamW(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay) 196 | # self.meta_opt = t.optim.SGD(self.meta_weight_net.parameters(), lr = args.meta_lr, weight_decay=args.meta_opt_weight_decay, momentum=0.95, nesterov=True) 197 | # self.scheduler = t.optim.lr_scheduler.CyclicLR(self.opt, args.opt_base_lr, args.opt_max_lr, step_size_up=1, step_size_down=2, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1) 198 | # self.meta_scheduler = t.optim.lr_scheduler.CyclicLR(self.meta_opt, args.meta_opt_base_lr, args.meta_opt_max_lr, step_size_up=1, step_size_down=2, mode='triangular', gamma=0.99, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.9, max_momentum=0.99, last_epoch=-1) 199 | # # exp_range 200 | 201 | 202 | if use_cuda: 203 | self.model = self.model.cuda() 204 | 205 | def innerProduct(self, u, i, j): 206 | pred_i = t.sum(t.mul(u,i), dim=1)*args.inner_product_mult 207 | pred_j = t.sum(t.mul(u,j), dim=1)*args.inner_product_mult 208 | return pred_i, pred_j 209 | 210 | def SSL(self, user_embeddings, item_embeddings, target_user_embeddings, target_item_embeddings, user_step_index): 211 | def row_shuffle(embedding): 212 | corrupted_embedding = embedding[t.randperm(embedding.size()[0])] 213 | return corrupted_embedding 214 | def row_column_shuffle(embedding): 215 | corrupted_embedding = embedding[t.randperm(embedding.size()[0])] 216 | corrupted_embedding = corrupted_embedding[:,t.randperm(corrupted_embedding.size()[1])] 217 | return corrupted_embedding 218 | def score(x1, x2): 219 | return t.sum(t.mul(x1, x2), 1) 220 | 221 | def neg_sample_pair(x1, x2, τ = 0.05): 222 | for i in range(x1.shape[0]): 223 | index_set = set(np.arange(x1.shape[0])) 224 | index_set.remove(i) 225 | index_set_neg = t.as_tensor(np.array(list(index_set))).long().cuda() 226 | 227 | x_pos = x1[i].repeat(x1.shape[0]-1, 1) 228 | x_neg = x2[index_set] 229 | 230 | if i==0: 231 | x_pos_all = x_pos 232 | x_neg_all = x_neg 233 | else: 234 | x_pos_all = t.cat((x_pos_all, x_pos), 0) 235 | x_neg_all = t.cat((x_neg_all, x_neg), 0) 236 | x_pos_all = t.as_tensor(x_pos_all) #[9900, 100] 237 | x_neg_all = t.as_tensor(x_neg_all) #[9900, 100] 238 | 239 | return x_pos_all, x_neg_all 240 | 241 | def one_neg_sample_pair_index(i, step_index, embedding1, embedding2): 242 | 243 | index_set = set(np.array(step_index)) 244 | index_set.remove(i.item()) 245 | neg2_index = t.as_tensor(np.array(list(index_set))).long().cuda() 246 | 247 | neg1_index = t.ones((2,), dtype=t.long) 248 | neg1_index = neg1_index.new_full((len(index_set),), i) 249 | 250 | neg_score_pre = t.sum(compute(embedding1, embedding2, neg1_index, neg2_index).squeeze()) 251 | return neg_score_pre 252 | 253 | def multi_neg_sample_pair_index(batch_index, step_index, embedding1, embedding2): #small, big, target, beh: [100], [1024], [31882, 16], [31882, 16] 254 | 255 | index_set = 
set(np.array(step_index.cpu())) 256 | batch_index_set = set(np.array(batch_index.cpu())) 257 | neg2_index_set = index_set - batch_index_set #beh 258 | neg2_index = t.as_tensor(np.array(list(neg2_index_set))).long().cuda() #[910] 259 | neg2_index = t.unsqueeze(neg2_index, 0) #[1, 910] 260 | neg2_index = neg2_index.repeat(len(batch_index), 1) #[100, 910] 261 | neg2_index = t.reshape(neg2_index, (1, -1)) #[1, 91000] 262 | neg2_index = t.squeeze(neg2_index) #[91000] 263 | #target 264 | neg1_index = batch_index.long().cuda() #[100] 265 | neg1_index = t.unsqueeze(neg1_index, 1) #[100, 1] 266 | neg1_index = neg1_index.repeat(1, len(neg2_index_set)) #[100, 910] 267 | neg1_index = t.reshape(neg1_index, (1, -1)) #[1, 91000] 268 | neg1_index = t.squeeze(neg1_index) #[91000] 269 | 270 | neg_score_pre = t.sum(compute(embedding1, embedding2, neg1_index, neg2_index).squeeze().view(len(batch_index), -1), -1) #[91000,1]==>[91000]==>[100, 910]==>[100] 271 | return neg_score_pre #[100] 272 | 273 | def compute(x1, x2, neg1_index=None, neg2_index=None, τ = 0.05): #[1024, 16], [1024, 16] 274 | 275 | if neg1_index!=None: 276 | x1 = x1[neg1_index] 277 | x2 = x2[neg2_index] 278 | 279 | N = x1.shape[0] 280 | D = x1.shape[1] 281 | 282 | x1 = x1 283 | x2 = x2 284 | 285 | scores = t.exp(t.div(t.bmm(x1.view(N, 1, D), x2.view(N, D, 1)).view(N, 1), np.power(D, 1)+1e-8)) #[1024, 1] 286 | 287 | return scores 288 | def single_infoNCE_loss_simple(embedding1, embedding2): 289 | pos = score(embedding1, embedding2) #[100] 290 | neg1 = score(embedding2, row_column_shuffle(embedding1)) 291 | one = t.cuda.FloatTensor(neg1.shape[0]).fill_(1) #[100] 292 | # one = zeros = t.ones(neg1.shape[0]) 293 | con_loss = t.sum(-t.log(1e-8 + t.sigmoid(pos))-t.log(1e-8 + (one - t.sigmoid(neg1)))) 294 | return con_loss 295 | 296 | #use_less 297 | def single_infoNCE_loss(embedding1, embedding2): 298 | N = embedding1.shape[0] 299 | D = embedding1.shape[1] 300 | 301 | pos_score = compute(embedding1, embedding2).squeeze() #[100, 1] 302 | 303 | neg_x1, neg_x2 = neg_sample_pair(embedding1, embedding2) #[9900, 100], [9900, 100] 304 | neg_score = t.sum(compute(neg_x1, neg_x2).view(N, (N-1)), dim=1) #[100] 305 | con_loss = -t.log(1e-8 +t.div(pos_score, neg_score)) 306 | con_loss = t.mean(con_loss) 307 | return max(0, con_loss) 308 | 309 | def single_infoNCE_loss_one_by_one(embedding1, embedding2, step_index): #target, beh 310 | N = step_index.shape[0] 311 | D = embedding1.shape[1] 312 | 313 | pos_score = compute(embedding1[step_index], embedding2[step_index]).squeeze() #[1024] 314 | neg_score = t.zeros((N,), dtype = t.float64).cuda() #[1024] 315 | 316 | #-------------------------------------------------multi version----------------------------------------------------- 317 | steps = int(np.ceil(N / args.SSL_batch)) #separate the batch to smaller one 318 | for i in range(steps): 319 | st = i * args.SSL_batch 320 | ed = min((i+1) * args.SSL_batch, N) 321 | batch_index = step_index[st: ed] 322 | 323 | neg_score_pre = multi_neg_sample_pair_index(batch_index, step_index, embedding1, embedding2) 324 | if i ==0: 325 | neg_score = neg_score_pre 326 | else: 327 | neg_score = t.cat((neg_score, neg_score_pre), 0) 328 | #-------------------------------------------------multi version----------------------------------------------------- 329 | 330 | con_loss = -t.log(1e-8 +t.div(pos_score, neg_score+1e-8)) #[1024]/[1024]==>1024 331 | 332 | 333 | assert not t.any(t.isnan(con_loss)) 334 | assert not t.any(t.isinf(con_loss)) 335 | 336 | return t.where(t.isnan(con_loss), 
t.full_like(con_loss, 0+1e-8), con_loss) 337 | 338 | user_con_loss_list = [] 339 | item_con_loss_list = [] 340 | 341 | SSL_len = int(user_step_index.shape[0]/10) 342 | user_step_index = t.as_tensor(np.random.choice(user_step_index.cpu(), size=SSL_len, replace=False, p=None)).cuda() 343 | 344 | for i in range(len(self.behaviors_SSL)): 345 | 346 | user_con_loss_list.append(single_infoNCE_loss_one_by_one(user_embeddings[-1], user_embeddings[i], user_step_index)) 347 | 348 | user_con_losss = t.stack(user_con_loss_list, dim=0) 349 | 350 | return user_con_loss_list, user_step_index #4*[1024] 351 | 352 | def run(self): 353 | 354 | self.prepareModel() 355 | if args.isload == True: 356 | print("----------------------pre test:") 357 | HR, NDCG = self.testEpoch(self.test_loader) 358 | print(f"HR: {HR} , NDCG: {NDCG}") 359 | log('Model Prepared') 360 | 361 | 362 | cvWait = 0 363 | self.best_HR = 0 364 | self.best_NDCG = 0 365 | flag = 0 366 | 367 | self.user_embed = None 368 | self.item_embed = None 369 | self.user_embeds = None 370 | self.item_embeds = None 371 | 372 | 373 | print("Test before train:") 374 | HR, NDCG = self.testEpoch(self.test_loader) 375 | 376 | for e in range(self.curEpoch, args.epoch+1): 377 | self.curEpoch = e 378 | 379 | self.meta_flag = 0 380 | if e%args.meta_slot == 0: 381 | self.meta_flag=1 382 | 383 | 384 | log("*****************Start epoch: %d ************************"%e) 385 | 386 | if args.isJustTest == False: 387 | epoch_loss, user_embed, item_embed, user_embeds, item_embeds = self.trainEpoch() 388 | self.train_loss.append(epoch_loss) 389 | print(f"epoch {e/args.epoch}, epoch loss{epoch_loss}") 390 | self.train_loss.append(epoch_loss) 391 | else: 392 | break 393 | 394 | HR, NDCG = self.testEpoch(self.test_loader) 395 | self.his_hr.append(HR) 396 | self.his_ndcg.append(NDCG) 397 | 398 | self.scheduler.step() 399 | self.meta_scheduler.step() 400 | 401 | if HR > self.best_HR: 402 | self.best_HR = HR 403 | self.best_epoch = self.curEpoch 404 | cvWait = 0 405 | print("--------------------------------------------------------------------------------------------------------------------------best_HR", self.best_HR) 406 | # print("--------------------------------------------------------------------------------------------------------------------------NDCG", self.best_NDCG) 407 | self.user_embed = user_embed 408 | self.item_embed = item_embed 409 | self.user_embeds = user_embeds 410 | self.item_embeds = item_embeds 411 | 412 | self.saveHistory() 413 | self.saveModel() 414 | 415 | 416 | 417 | if NDCG > self.best_NDCG: 418 | self.best_NDCG = NDCG 419 | self.best_epoch = self.curEpoch 420 | cvWait = 0 421 | # print("--------------------------------------------------------------------------------------------------------------------------HR", self.best_HR) 422 | print("--------------------------------------------------------------------------------------------------------------------------best_NDCG", self.best_NDCG) 423 | self.user_embed = user_embed 424 | self.item_embed = item_embed 425 | self.user_embeds = user_embeds 426 | self.item_embeds = item_embeds 427 | 428 | self.saveHistory() 429 | self.saveModel() 430 | 431 | 432 | 433 | if (HR