├── LICENSE ├── README.md ├── asset └── diffurec_framework.png ├── datasets ├── data │ ├── amazon_beauty │ │ ├── dataset.pkl │ │ └── readme.md │ ├── ml-1m │ │ ├── dataset.pkl │ │ └── readme.md │ ├── steam │ │ ├── dataset.pkl │ │ └── readme.md │ └── toys │ │ └── dataset.pkl └── readme.md └── src ├── diffurec.py ├── main.py ├── model.py ├── step_sample.py ├── trainer.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Information Retrieval Group, Wuhan University, China 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffuRec 2 | 3 | This is a PyTorch implementation for our [DiffuRec](https://arxiv.org/abs/2304.00686) paper: 4 | 5 | > Zihao Li, Aixin Sun, and Chenliang Li. 2023. DiffuRec: A Diffusion Model for Sequential Recommendation. ACM Trans. Inf. Syst. 42, 3, Article 66 (May 2024), 28 pages. https://doi.org/10.1145/3631116 6 | 7 | ## Overview 8 | Mainstream solutions to Sequential Recommendation (SR) represent items with fixed vectors. These vectors have limited capability in capturing items’ latent aspects and users' diverse preferences. As a new generative paradigm, Diffusion models have achieved excellent performance in areas like computer vision and natural language processing. To our understanding, its unique merit in representation generation well fits the problem setting of sequential recommendation. In this paper, we make the very first attempt to adapt Diffusion model to SR and propose DiffuRec, for item representation construction and uncertainty injection. Rather than modeling item representations as fixed vectors, we represent them as distributions in DiffuRec, which reflect user's multiple interests and item's various aspects adaptively. In diffusion phase, DiffuRec corrupts the target item embedding into a Gaussian distribution via noise adding, which is further applied for sequential item distribution representation generation and uncertainty injection. Afterward, the item representation is fed into an Approximator for target item representation reconstruction. In reverse phase, based on user's historical interaction behaviors, we reverse a Gaussian noise into the target item representation, then apply a rounding operation for target item prediction. Experiments over four datasets show that DiffuRec outperforms strong baselines by a large margin. 
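The diffusion phase described above (corrupting the target item embedding with Gaussian noise) can be sketched in a few lines. This is a minimal, dependency-free illustration rather than the repository's implementation: the helper names `trunc_lin_betas`, `alphas_cumprod`, and `q_sample` and the toy embedding are assumptions made for this example. The schedule constants mirror the `trunc_lin` branch of `get_named_beta_schedule` in `src/diffurec.py`, and the sampling rule is the standard closed form $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$.

```python
import math
import random


def trunc_lin_betas(num_steps):
    """Truncated linear beta schedule (same constants as the 'trunc_lin' branch)."""
    scale = 1000 / num_steps
    beta_start = scale * 0.0001 + 0.01
    beta_end = scale * 0.02 + 0.01
    if beta_end > 1:
        beta_end = scale * 0.001 + 0.01
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def alphas_cumprod(betas):
    """Cumulative product of (1 - beta_t), i.e. alpha_bar_t for every step t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def q_sample(x_start, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
    signal = math.sqrt(alpha_bar[t])
    sigma = math.sqrt(1.0 - alpha_bar[t])
    return [signal * x + sigma * rng.gauss(0.0, 1.0) for x in x_start]


betas = trunc_lin_betas(32)          # 32 is the default --diffusion_steps in src/main.py
alpha_bar = alphas_cumprod(betas)
rng = random.Random(0)
x0 = [0.5, -1.0, 2.0, 0.0]           # hypothetical toy "target item embedding"
x_early = q_sample(x0, 0, alpha_bar, rng)    # step 0: still close to x0
x_late = q_sample(x0, 31, alpha_bar, rng)    # last step: almost pure Gaussian noise
```

With this schedule `alpha_bar` shrinks monotonically, so early steps keep most of the original embedding while the final step retains almost none of it.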
9 | 10 | ![Diffurec](asset/diffurec_framework.png) 11 | 12 | ## Requirements 13 | - Python 3.8.11 14 | - PyTorch 1.8.0 15 | - numpy 1.23.4 16 | 17 | Our code has been tested on a Linux desktop with an NVIDIA GeForce RTX 3090 GPU and an Intel Xeon E5-2680 v3 CPU. 18 | 19 | ## Usage 20 | 21 | 1. Clone this repo 22 | 23 | ``` 24 | git clone https://github.com/WHUIR/DiffuRec.git 25 | ``` 26 | 27 | 2. Run the command below for model training and evaluation (`main.py` is located in `src`). 28 | ``` 29 | python main.py --dataset amazon_beauty 30 | ``` 31 | 32 | ## Citation 33 | If you use this repository, please cite the following paper: 34 | ``` 35 | @article{10.1145/3631116, 36 | author = {Li, Zihao and Sun, Aixin and Li, Chenliang}, 37 | title = {DiffuRec: A Diffusion Model for Sequential Recommendation}, 38 | year = {2023}, 39 | issue_date = {May 2024}, 40 | publisher = {Association for Computing Machinery}, 41 | address = {New York, NY, USA}, 42 | volume = {42}, 43 | number = {3}, 44 | issn = {1046-8188}, 45 | doi = {10.1145/3631116}, 46 | journal = {ACM Trans. Inf. Syst.}, 47 | } 48 | ``` 49 | 50 | ## Acknowledgements 51 | 52 | [TimiRec](https://github.com/THUwangcy/ReChorus/tree/CIKM22), [SVAE](https://github.com/noveens/svae_cf), [ACVAE](https://github.com/ACVAE/ACVAE-PyTorch) and [STOSA](https://github.com/zfan20/STOSA). 
53 | -------------------------------------------------------------------------------- /asset/diffurec_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHUIR/DiffuRec/f6f70c9d78477591742c71c9ed83e5f1269530e8/asset/diffurec_framework.png -------------------------------------------------------------------------------- /datasets/data/amazon_beauty/dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHUIR/DiffuRec/f6f70c9d78477591742c71c9ed83e5f1269530e8/datasets/data/amazon_beauty/dataset.pkl -------------------------------------------------------------------------------- /datasets/data/amazon_beauty/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/data/ml-1m/dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHUIR/DiffuRec/f6f70c9d78477591742c71c9ed83e5f1269530e8/datasets/data/ml-1m/dataset.pkl -------------------------------------------------------------------------------- /datasets/data/ml-1m/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/data/steam/dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHUIR/DiffuRec/f6f70c9d78477591742c71c9ed83e5f1269530e8/datasets/data/steam/dataset.pkl -------------------------------------------------------------------------------- /datasets/data/steam/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/datasets/data/toys/dataset.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHUIR/DiffuRec/f6f70c9d78477591742c71c9ed83e5f1269530e8/datasets/data/toys/dataset.pkl -------------------------------------------------------------------------------- /datasets/readme.md: -------------------------------------------------------------------------------- 1 | # Data Processing 2 | 3 | Our data preprocessing methods and results are consistent with [ICLRec](https://github.com/salesforce/ICLRec/tree/master/data) and [DuoRec](https://github.com/RuihongQiu/DuoRec); please refer to them for details. 4 | 5 | -------------------------------------------------------------------------------- /src/diffurec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch as th 3 | from step_sample import create_named_schedule_sampler 4 | import numpy as np 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 11 | """ 12 | Extract values from a 1-D numpy array for a batch of indices. 13 | 14 | :param arr: the 1-D numpy array. 15 | :param timesteps: a tensor of indices into the array to extract. 16 | :param broadcast_shape: a larger shape of K dimensions with the batch 17 | dimension equal to the length of timesteps. 18 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 19 | """ 20 | 21 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 22 | while len(res.shape) < len(broadcast_shape): 23 | res = res[..., None] 24 | return res.expand(broadcast_shape) 25 | 26 | 27 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 28 | """ 29 | Get a pre-defined beta schedule for the given name. 30 | The beta schedule library consists of beta schedules which remain similar in the limit of num_diffusion_timesteps. 
Beta schedules may be added, but should not be removed or changed once they are committed to maintain backwards compatibility. 31 | """ 32 | if schedule_name == "linear": 33 | # Linear schedule from Ho et al, extended to work for any number of 34 | # diffusion steps. 35 | scale = 1000 / num_diffusion_timesteps 36 | beta_start = scale * 0.0001 37 | beta_end = scale * 0.02 38 | return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 39 | elif schedule_name == "cosine": 40 | return betas_for_alpha_bar(num_diffusion_timesteps, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,) 41 | elif schedule_name == 'sqrt': 42 | return betas_for_alpha_bar(num_diffusion_timesteps,lambda t: 1-np.sqrt(t + 0.0001), ) 43 | elif schedule_name == "trunc_cos": 44 | return betas_for_alpha_bar_left(num_diffusion_timesteps, lambda t: np.cos((t + 0.1) / 1.1 * np.pi / 2) ** 2,) 45 | elif schedule_name == 'trunc_lin': 46 | scale = 1000 / num_diffusion_timesteps 47 | beta_start = scale * 0.0001 + 0.01 48 | beta_end = scale * 0.02 + 0.01 49 | if beta_end > 1: 50 | beta_end = scale * 0.001 + 0.01 51 | return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 52 | elif schedule_name == 'pw_lin': 53 | scale = 1000 / num_diffusion_timesteps 54 | beta_start = scale * 0.0001 + 0.01 55 | beta_mid = scale * 0.0001 #scale * 0.02 56 | beta_end = scale * 0.02 57 | first_part = np.linspace(beta_start, beta_mid, 10, dtype=np.float64) 58 | second_part = np.linspace(beta_mid, beta_end, num_diffusion_timesteps - 10 , dtype=np.float64) 59 | return np.concatenate([first_part, second_part]) 60 | else: 61 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 62 | 63 | 64 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 65 | """ 66 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. 
67 | :param num_diffusion_timesteps: the number of betas to produce. 68 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up to that part of the diffusion process. 69 | :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. 70 | """ 71 | betas = [] 72 | for i in range(num_diffusion_timesteps): ## 2000 73 | t1 = i / num_diffusion_timesteps 74 | t2 = (i + 1) / num_diffusion_timesteps 75 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 76 | return np.array(betas) 77 | 78 | 79 | def betas_for_alpha_bar_left(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 80 | """ 81 | Create a beta schedule that discretizes the given alpha_t_bar function, but shifts towards left interval starting from 0 82 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 83 | 84 | :param num_diffusion_timesteps: the number of betas to produce. 85 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 86 | produces the cumulative product of (1-beta) up to that 87 | part of the diffusion process. 88 | :param max_beta: the maximum beta to use; use values lower than 1 to 89 | prevent singularities. 90 | """ 91 | betas = [] 92 | betas.append(min(1-alpha_bar(0), max_beta)) 93 | for i in range(num_diffusion_timesteps-1): 94 | t1 = i / num_diffusion_timesteps 95 | t2 = (i + 1) / num_diffusion_timesteps 96 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 97 | return np.array(betas) 98 | 99 | 100 | def space_timesteps(num_timesteps, section_counts): 101 | """ 102 | Create a list of timesteps to use from an original diffusion process, 103 | given the number of timesteps we want to take from equally-sized portions 104 | of the original process. 
105 | 106 | For example, if there are 300 timesteps and the section counts are [10,15,20] 107 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 108 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 109 | 110 | If the stride is a string starting with "ddim", then the fixed striding 111 | from the DDIM paper is used, and only one section is allowed. 112 | 113 | :param num_timesteps: the number of diffusion steps in the original 114 | process to divide up. 115 | :param section_counts: either a list of numbers, or a string containing 116 | comma-separated numbers, indicating the step count 117 | per section. As a special case, use "ddimN" where N 118 | is a number of steps to use the striding from the 119 | DDIM paper. 120 | :return: a set of diffusion steps from the original process to use. 121 | """ 122 | if isinstance(section_counts, str): 123 | if section_counts.startswith("ddim"): 124 | desired_count = int(section_counts[len("ddim") :]) 125 | for i in range(1, num_timesteps): 126 | if len(range(0, num_timesteps, i)) == desired_count: 127 | return set(range(0, num_timesteps, i)) 128 | raise ValueError( 129 | f"cannot create exactly {desired_count} steps with an integer stride" 130 | ) 131 | section_counts = [int(x) for x in section_counts.split(",")] 132 | size_per = num_timesteps // len(section_counts) 133 | extra = num_timesteps % len(section_counts) 134 | start_idx = 0 135 | all_steps = [] 136 | for i, section_count in enumerate(section_counts): 137 | size = size_per + (1 if i < extra else 0) 138 | if size < section_count: 139 | raise ValueError( 140 | f"cannot divide section of {size} steps into {section_count}" 141 | ) 142 | if section_count <= 1: 143 | frac_stride = 1 144 | else: 145 | frac_stride = (size - 1) / (section_count - 1) 146 | cur_idx = 0.0 147 | taken_steps = [] 148 | for _ in range(section_count): 149 | taken_steps.append(start_idx + round(cur_idx)) 150 | cur_idx += frac_stride 151 | all_steps += 
taken_steps 152 | start_idx += size 153 | return set(all_steps) 154 | 155 | 156 | class SiLU(nn.Module): 157 | def forward(self, x): 158 | return x * th.sigmoid(x) 159 | 160 | 161 | class LayerNorm(nn.Module): 162 | def __init__(self, hidden_size, eps=1e-12): 163 | """Construct a layernorm module in the TF style (epsilon inside the square root). 164 | """ 165 | super(LayerNorm, self).__init__() 166 | self.weight = nn.Parameter(torch.ones(hidden_size)) 167 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 168 | self.variance_epsilon = eps 169 | 170 | def forward(self, x): 171 | u = x.mean(-1, keepdim=True) 172 | s = (x - u).pow(2).mean(-1, keepdim=True) 173 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 174 | return self.weight * x + self.bias 175 | 176 | 177 | class SublayerConnection(nn.Module): 178 | """ 179 | A residual connection followed by a layer norm. 180 | Note for code simplicity the norm is first as opposed to last. 181 | """ 182 | 183 | def __init__(self, hidden_size, dropout): 184 | super(SublayerConnection, self).__init__() 185 | self.norm = LayerNorm(hidden_size) 186 | self.dropout = nn.Dropout(dropout) 187 | 188 | def forward(self, x, sublayer): 189 | "Apply residual connection to any sublayer with the same size." 190 | return x + self.dropout(sublayer(self.norm(x))) 191 | 192 | 193 | class PositionwiseFeedForward(nn.Module): 194 | "Implements FFN equation." 
195 | 196 | def __init__(self, hidden_size, dropout=0.1): 197 | super(PositionwiseFeedForward, self).__init__() 198 | self.w_1 = nn.Linear(hidden_size, hidden_size*4) 199 | self.w_2 = nn.Linear(hidden_size*4, hidden_size) 200 | self.dropout = nn.Dropout(dropout) 201 | self.init_weights() 202 | 203 | def init_weights(self): 204 | nn.init.xavier_normal_(self.w_1.weight) 205 | nn.init.xavier_normal_(self.w_2.weight) 206 | 207 | def forward(self, hidden): 208 | hidden = self.w_1(hidden) 209 | activation = 0.5 * hidden * (1 + torch.tanh(math.sqrt(2 / math.pi) * (hidden + 0.044715 * torch.pow(hidden, 3)))) 210 | return self.w_2(self.dropout(activation)) 211 | 212 | 213 | class MultiHeadedAttention(nn.Module): 214 | def __init__(self, heads, hidden_size, dropout): 215 | super().__init__() 216 | assert hidden_size % heads == 0 217 | self.size_head = hidden_size // heads 218 | self.num_heads = heads 219 | self.linear_layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(3)]) 220 | self.w_layer = nn.Linear(hidden_size, hidden_size) 221 | self.dropout = nn.Dropout(p=dropout) 222 | self.init_weights() 223 | 224 | def init_weights(self): 225 | nn.init.xavier_normal_(self.w_layer.weight) 226 | 227 | def forward(self, q, k, v, mask=None): 228 | batch_size = q.shape[0] 229 | q, k, v = [l(x).view(batch_size, -1, self.num_heads, self.size_head).transpose(1, 2) for l, x in zip(self.linear_layers, (q, k, v))] 230 | corr = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) 231 | 232 | if mask is not None: 233 | mask = mask.unsqueeze(1).repeat([1, corr.shape[1], 1]).unsqueeze(-1).repeat([1,1,1,corr.shape[-1]]) 234 | corr = corr.masked_fill(mask == 0, -1e9) 235 | prob_attn = F.softmax(corr, dim=-1) 236 | if self.dropout is not None: 237 | prob_attn = self.dropout(prob_attn) 238 | hidden = torch.matmul(prob_attn, v) 239 | hidden = self.w_layer(hidden.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.size_head)) 240 | return hidden 241 | 
242 | 243 | class TransformerBlock(nn.Module): 244 | def __init__(self, hidden_size, attn_heads, dropout): 245 | super(TransformerBlock, self).__init__() 246 | self.attention = MultiHeadedAttention(heads=attn_heads, hidden_size=hidden_size, dropout=dropout) 247 | self.feed_forward = PositionwiseFeedForward(hidden_size=hidden_size, dropout=dropout) 248 | self.input_sublayer = SublayerConnection(hidden_size=hidden_size, dropout=dropout) 249 | self.output_sublayer = SublayerConnection(hidden_size=hidden_size, dropout=dropout) 250 | self.dropout = nn.Dropout(p=dropout) 251 | 252 | def forward(self, hidden, mask): 253 | hidden = self.input_sublayer(hidden, lambda _hidden: self.attention.forward(_hidden, _hidden, _hidden, mask=mask)) 254 | hidden = self.output_sublayer(hidden, self.feed_forward) 255 | return self.dropout(hidden) 256 | 257 | 258 | class Transformer_rep(nn.Module): 259 | def __init__(self, args): 260 | super(Transformer_rep, self).__init__() 261 | self.hidden_size = args.hidden_size 262 | self.heads = 4 263 | self.dropout = args.dropout 264 | self.n_blocks = args.num_blocks 265 | self.transformer_blocks = nn.ModuleList( 266 | [TransformerBlock(self.hidden_size, self.heads, self.dropout) for _ in range(self.n_blocks)]) 267 | 268 | def forward(self, hidden, mask): 269 | for transformer in self.transformer_blocks: 270 | hidden = transformer.forward(hidden, mask) 271 | return hidden 272 | 273 | 274 | class Diffu_xstart(nn.Module): 275 | def __init__(self, hidden_size, args): 276 | super(Diffu_xstart, self).__init__() 277 | self.hidden_size = hidden_size 278 | self.linear_item = nn.Linear(self.hidden_size, self.hidden_size) 279 | self.linear_xt = nn.Linear(self.hidden_size, self.hidden_size) 280 | self.linear_t = nn.Linear(self.hidden_size, self.hidden_size) 281 | time_embed_dim = self.hidden_size * 4 282 | self.time_embed = nn.Sequential(nn.Linear(self.hidden_size, time_embed_dim), SiLU(), nn.Linear(time_embed_dim, self.hidden_size)) 283 | self.fuse_linear = 
nn.Linear(self.hidden_size*3, self.hidden_size) 284 | self.att = Transformer_rep(args) 285 | # self.mlp_model = nn.Linear(self.hidden_size, self.hidden_size) 286 | # self.gru_model = nn.GRU(self.hidden_size, self.hidden_size, batch_first=True) 287 | # self.gru_model = nn.GRU(self.hidden_size, self.hidden_size, num_layers=args.num_blocks, batch_first=True) 288 | self.lambda_uncertainty = args.lambda_uncertainty 289 | self.dropout = nn.Dropout(args.dropout) 290 | self.norm_diffu_rep = LayerNorm(self.hidden_size) 291 | 292 | def timestep_embedding(self, timesteps, dim, max_period=10000): 293 | """ 294 | Create sinusoidal timestep embeddings. 295 | 296 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 297 | These may be fractional. 298 | :param dim: the dimension of the output. 299 | :param max_period: controls the minimum frequency of the embeddings. 300 | :return: an [N x dim] Tensor of positional embeddings. 301 | """ 302 | half = dim // 2 303 | freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(device=timesteps.device) 304 | args = timesteps[:, None].float() * freqs[None] 305 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 306 | if dim % 2: 307 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 308 | return embedding 309 | 310 | def forward(self, rep_item, x_t, t, mask_seq): 311 | emb_t = self.time_embed(self.timestep_embedding(t, self.hidden_size)) 312 | x_t = x_t + emb_t 313 | 314 | # lambda_uncertainty = th.normal(mean=th.full(rep_item.shape, 1.0), std=th.full(rep_item.shape, 1.0)).to(x_t.device) 315 | 316 | lambda_uncertainty = th.normal(mean=th.full(rep_item.shape, self.lambda_uncertainty), std=th.full(rep_item.shape, self.lambda_uncertainty)).to(x_t.device) ## distribution 317 | # lambda_uncertainty = self.lambda_uncertainty ### fixed 318 | 319 | #### Attention 320 | rep_diffu = self.att(rep_item + lambda_uncertainty * x_t.unsqueeze(1), mask_seq) 321 | 
rep_diffu = self.norm_diffu_rep(self.dropout(rep_diffu)) 322 | out = rep_diffu[:, -1, :] 323 | 324 | 325 | ## rep_diffu = self.att(rep_item, mask_seq) ## do not use 326 | ## rep_diffu = self.dropout(self.norm_diffu_rep(rep_diffu)) ## do not use 327 | 328 | #### 329 | 330 | #### GRU 331 | # output, hn = self.gru_model(rep_item + lambda_uncertainty * x_t.unsqueeze(1)) 332 | # output = self.norm_diffu_rep(self.dropout(output)) 333 | # out = output[:,-1,:] 334 | ## # out = hn.squeeze(0) 335 | # rep_diffu = None 336 | #### 337 | 338 | ### MLP 339 | # output = self.mlp_model(rep_item + lambda_uncertainty * x_t.unsqueeze(1)) 340 | # output = self.norm_diffu_rep(self.dropout(output)) 341 | # out = output[:,-1,:] 342 | # rep_diffu = None 343 | ### 344 | 345 | # out = out + self.lambda_uncertainty * x_t 346 | 347 | return out, rep_diffu 348 | 349 | 350 | class DiffuRec(nn.Module): 351 | def __init__(self, args,): 352 | super(DiffuRec, self).__init__() 353 | self.hidden_size = args.hidden_size 354 | self.schedule_sampler_name = args.schedule_sampler_name 355 | self.diffusion_steps = args.diffusion_steps 356 | self.use_timesteps = space_timesteps(self.diffusion_steps, [self.diffusion_steps]) 357 | 358 | self.noise_schedule = args.noise_schedule 359 | betas = self.get_betas(self.noise_schedule, self.diffusion_steps) 360 | # Use float64 for accuracy. 
361 | betas = np.array(betas, dtype=np.float64) 362 | self.betas = betas 363 | assert len(betas.shape) == 1, "betas must be 1-D" 364 | assert (betas > 0).all() and (betas <= 1).all() 365 | alphas = 1.0 - betas 366 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 367 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 368 | 369 | # calculations for diffusion q(x_t | x_{t-1}) and others 370 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 371 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 372 | 373 | # calculations for diffusion q(x_t | x_{t-1}) and others 374 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 375 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 376 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 377 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 378 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 379 | 380 | self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) 381 | self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)) 382 | 383 | # calculations for posterior q(x_{t-1} | x_t, x_0) 384 | self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)) 385 | 386 | self.num_timesteps = int(self.betas.shape[0]) 387 | 388 | self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_name, self.num_timesteps) ## lossaware (schedule_sample) 389 | self.timestep_map = self.time_map() 390 | self.rescale_timesteps = args.rescale_timesteps 391 | self.original_num_steps = len(betas) 392 | 393 | self.xstart_model = Diffu_xstart(self.hidden_size, args) 394 | 395 | def get_betas(self, noise_schedule, diffusion_steps): 396 | betas = get_named_beta_schedule(noise_schedule, diffusion_steps) ## array, generate beta 397 | return betas 398 | 399 | 400 | def 
q_sample(self, x_start, t, noise=None, mask=None): 401 | """ 402 | Diffuse the data for a given number of diffusion steps. 403 | 404 | In other words, sample from q(x_t | x_0). 405 | 406 | :param x_start: the initial data batch. 407 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 408 | :param noise: if specified, the split-out normal noise. 409 | :param mask: anchoring masked position 410 | :return: A noisy version of x_start. 411 | """ 412 | if noise is None: 413 | noise = th.randn_like(x_start) 414 | 415 | assert noise.shape == x_start.shape 416 | x_t = ( 417 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 418 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 419 | * noise ## reparameterization trick 420 | ) ## generate x_t from x_0 (x_start) with the reparameterization trick 421 | 422 | if mask is None: 423 | return x_t 424 | else: 425 | mask = th.broadcast_to(mask.unsqueeze(dim=-1), x_start.shape) ## mask: [0,0,0,1,1,1,1,1] 426 | return th.where(mask==0, x_start, x_t) ## keep x_0 where mask == 0, use the noised x_t elsewhere 427 | 428 | def time_map(self): 429 | timestep_map = [] 430 | for i in range(len(self.alphas_cumprod)): 431 | if i in self.use_timesteps: 432 | timestep_map.append(i) 433 | return timestep_map 434 | 435 | # def scale_t(self, ts): 436 | # map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 437 | # new_ts = map_tensor[ts] 438 | # # print(new_ts) 439 | # if self.rescale_timesteps: 440 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 441 | # return new_ts 442 | 443 | def _scale_timesteps(self, t): 444 | if self.rescale_timesteps: 445 | return t.float() * (1000.0 / self.num_timesteps) 446 | return t 447 | 448 | def _predict_xstart_from_eps(self, x_t, t, eps): 449 | 450 | assert x_t.shape == eps.shape 451 | return ( 452 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 453 | - 
_extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 454 | ) 455 | 456 | def q_posterior_mean_variance(self, x_start, x_t, t): 457 | """ 458 | Compute the mean of the diffusion posterior: 459 | q(x_{t-1} | x_t, x_0) 460 | 461 | """ 462 | assert x_start.shape == x_t.shape 463 | posterior_mean = ( 464 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 465 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 466 | ) ## \mu_t 467 | assert (posterior_mean.shape[0] == x_start.shape[0]) 468 | return posterior_mean 469 | 470 | def p_mean_variance(self, rep_item, x_t, t, mask_seq): 471 | model_output, _ = self.xstart_model(rep_item, x_t, self._scale_timesteps(t), mask_seq) 472 | 473 | x_0 = model_output ## the model predicts x_0 directly 474 | # x_0 = self._predict_xstart_from_eps(x_t, t, model_output) ## eps predict 475 | 476 | model_log_variance = np.log(np.append(self.posterior_variance[1], self.betas[1:])) 477 | model_log_variance = _extract_into_tensor(model_log_variance, t, x_t.shape) 478 | 479 | model_mean = self.q_posterior_mean_variance(x_start=x_0, x_t=x_t, t=t) ## x_start: candidate item embedding, x_t: inputseq_embedding + outseq_noise, output: mean of the x_(t-1) distribution 480 | return model_mean, model_log_variance 481 | 482 | def p_sample(self, item_rep, noise_x_t, t, mask_seq): 483 | model_mean, model_log_variance = self.p_mean_variance(item_rep, noise_x_t, t, mask_seq) 484 | noise = th.randn_like(noise_x_t) 485 | nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(noise_x_t.shape) - 1)))) # no noise when t == 0 486 | sample_xt = model_mean + nonzero_mask * th.exp(0.5 * model_log_variance) * noise ## sample x_{t-1} around the predicted mean using the reparameterization trick 487 | return sample_xt 488 | 489 | def reverse_p_sample(self, item_rep, noise_x_t, mask_seq): 490 | device = next(self.xstart_model.parameters()).device 491 | indices = list(range(self.num_timesteps))[::-1] 492 | 493 | for i in 
indices: # from T to 0, reversion iteration 494 | t = th.tensor([i] * item_rep.shape[0], device=device) 495 | with th.no_grad(): 496 | noise_x_t = self.p_sample(item_rep, noise_x_t, t, mask_seq) 497 | return noise_x_t 498 | 499 | def forward(self, item_rep, item_tag, mask_seq): 500 | noise = th.randn_like(item_tag) 501 | t, weights = self.schedule_sampler.sample(item_rep.shape[0], item_tag.device) ## t is sampled from schedule_sampler 502 | 503 | # t = self.scale_t(t) 504 | x_t = self.q_sample(item_tag, t, noise=noise) 505 | 506 | # eps, item_rep_out = self.xstart_model(item_rep, x_t, self._scale_timesteps(t), mask_seq) ## eps predict 507 | # x_0 = self._predict_xstart_from_eps(x_t, t, eps) 508 | 509 | x_0, item_rep_out = self.xstart_model(item_rep, x_t, self._scale_timesteps(t), mask_seq) ##output predict 510 | return x_0, item_rep_out, weights, t 511 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import numpy as np 7 | import logging 8 | import time 9 | import pickle 10 | from utils import Data_Train, Data_Val, Data_Test, Data_CHLS 11 | from model import create_model_diffu, Att_Diffuse_model 12 | from trainer import model_train, LSHT_inference 13 | from collections import Counter 14 | 15 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--dataset', default='amazon_beauty', help='Dataset name: toys, amazon_beauty, steam, ml-1m') 20 | parser.add_argument('--log_file', default='log/', help='log dir path') 21 | parser.add_argument('--random_seed', type=int, default=1997, help='Random seed') 22 | parser.add_argument('--max_len', type=int, default=50, help='The max length of sequence') 23 | parser.add_argument('--device', type=str, default='cuda', 
choices=['cpu', 'cuda']) 24 | parser.add_argument('--num_gpu', type=int, default=1, help='Number of GPU') 25 | parser.add_argument('--batch_size', type=int, default=512, help='Batch Size') 26 | parser.add_argument("--hidden_size", default=128, type=int, help="hidden size of model") 27 | parser.add_argument('--dropout', type=float, default=0.1, help='Dropout of representation') 28 | parser.add_argument('--emb_dropout', type=float, default=0.3, help='Dropout of item embedding') 29 | parser.add_argument("--hidden_act", default="gelu", type=str) # gelu relu 30 | parser.add_argument('--num_blocks', type=int, default=4, help='Number of Transformer blocks') 31 | parser.add_argument('--epochs', type=int, default=500, help='Number of epochs for training') ## 500 32 | parser.add_argument('--decay_step', type=int, default=100, help='Decay step for StepLR') 33 | parser.add_argument('--gamma', type=float, default=0.1, help='Gamma for StepLR') 34 | parser.add_argument('--metric_ks', nargs='+', type=int, default=[5, 10, 20], help='ks for Metric@k') 35 | parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adam']) 36 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 37 | parser.add_argument('--loss_lambda', type=float, default=0.001, help='loss weight for diffusion') 38 | parser.add_argument('--weight_decay', type=float, default=0, help='L2 regularization') 39 | parser.add_argument('--momentum', type=float, default=None, help='SGD momentum') 40 | parser.add_argument('--schedule_sampler_name', type=str, default='lossaware', help='Diffusion for t generation') 41 | parser.add_argument('--diffusion_steps', type=int, default=32, help='Diffusion step') 42 | parser.add_argument('--lambda_uncertainty', type=float, default=0.001, help='uncertainty weight') 43 | parser.add_argument('--noise_schedule', default='trunc_lin', help='Beta generation') ## cosine, linear, trunc_cos, trunc_lin, pw_lin, sqrt 44 | 
parser.add_argument('--rescale_timesteps', default=True, help='Rescale timesteps') 45 | parser.add_argument('--eval_interval', type=int, default=20, help='Evaluate every N epochs') 46 | parser.add_argument('--patience', type=int, default=5, help='Number of evaluations without improvement before early stopping') 47 | parser.add_argument('--description', type=str, default='Diffu_norm_score', help='Model brief introduction') 48 | parser.add_argument('--long_head', default=False, help='Long and short sequence, head and long-tail items') 49 | parser.add_argument('--diversity_measure', default=False, help='Measure the diversity of recommendation results') 50 | parser.add_argument('--epoch_time_avg', default=False, help='Calculate the average time of one epoch training') 51 | args = parser.parse_args() 52 | 53 | print(args) 54 | 55 | if not os.path.exists(args.log_file): 56 | os.makedirs(args.log_file) 57 | if not os.path.exists(args.log_file + args.dataset): 58 | os.makedirs(args.log_file + args.dataset ) 59 | 60 | 61 | logging.basicConfig(level=logging.INFO, filename=args.log_file + args.dataset + '/' + time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) + '.log', 62 | datefmt='%Y/%m/%d %H:%M:%S', format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)d - %(module)s - %(message)s', filemode='w') 63 | logger = logging.getLogger(__name__) 64 | logger.info(args) 65 | 66 | 67 | def fix_random_seed_as(random_seed): 68 | random.seed(random_seed) 69 | torch.manual_seed(random_seed) 70 | torch.cuda.manual_seed_all(random_seed) 71 | np.random.seed(random_seed) 72 | cudnn.deterministic = True 73 | cudnn.benchmark = False 74 | 75 | 76 | def item_num_create(args, item_num): 77 | args.item_num = item_num 78 | return args 79 | 80 | 81 | def cold_hot_long_short(data_raw, dataset_name): 82 | item_list = [] 83 | len_list = [] 84 | target_item = [] 85 | 86 | for id_temp in data_raw['train']: 87 | temp_list = data_raw['train'][id_temp] + data_raw['val'][id_temp] + data_raw['test'][id_temp] 88 | 
len_list.append(len(temp_list)) 89 | target_item.append(data_raw['test'][id_temp][0]) 90 | item_list += temp_list 91 | item_num_count = Counter(item_list) 92 | split_num = np.percentile(list(item_num_count.values()), 80) 93 | cold_item, hot_item = [], [] 94 | for item_num_temp in item_num_count.items(): 95 | if item_num_temp[1] < split_num: 96 | cold_item.append(item_num_temp[0]) 97 | else: 98 | hot_item.append(item_num_temp[0]) 99 | cold_ids, hot_ids = [], [] 100 | cold_list, hot_list = [], [] 101 | for id_temp, item_temp in enumerate(data_raw['test'].values()): 102 | if item_temp[0] in hot_item: 103 | hot_ids.append(id_temp) 104 | if dataset_name == 'ml-1m': 105 | hot_list.append(data_raw['train'][id_temp+1] + data_raw['val'][id_temp+1] + data_raw['test'][id_temp+1]) 106 | else: 107 | hot_list.append(data_raw['train'][id_temp] + data_raw['val'][id_temp] + data_raw['test'][id_temp]) 108 | else: 109 | cold_ids.append(id_temp) 110 | if dataset_name == 'ml-1m': 111 | cold_list.append(data_raw['train'][id_temp+1] + data_raw['val'][id_temp+1] + data_raw['test'][id_temp+1]) 112 | else: 113 | cold_list.append(data_raw['train'][id_temp] + data_raw['val'][id_temp] + data_raw['test'][id_temp]) 114 | cold_hot_dict = {'hot': hot_list, 'cold': cold_list} 115 | 116 | len_short = np.percentile(len_list, 20) 117 | len_midshort = np.percentile(len_list, 40) 118 | len_midlong = np.percentile(len_list, 60) 119 | len_long = np.percentile(len_list, 80) 120 | 121 | len_seq_dict = {'short': [], 'mid_short': [], 'mid': [], 'mid_long': [], 'long': []} 122 | for id_temp, len_temp in enumerate(len_list): 123 | if dataset_name == 'ml-1m': 124 | temp_seq = data_raw['train'][id_temp+1] + data_raw['val'][id_temp+1] + data_raw['test'][id_temp+1] 125 | else: 126 | temp_seq = data_raw['train'][id_temp] + data_raw['val'][id_temp] + data_raw['test'][id_temp] 127 | if len_temp <= len_short: 128 | len_seq_dict['short'].append(temp_seq) 129 | elif len_short < len_temp <= len_midshort: 130 | 
len_seq_dict['mid_short'].append(temp_seq) 131 | elif len_midshort < len_temp <= len_midlong: 132 | len_seq_dict['mid'].append(temp_seq) 133 | elif len_midlong < len_temp <= len_long: 134 | len_seq_dict['mid_long'].append(temp_seq) 135 | else: 136 | len_seq_dict['long'].append(temp_seq) 137 | return cold_hot_dict, len_seq_dict, split_num, [len_short, len_midshort, len_midlong, len_long], len_list, list(item_num_count.values()) 138 | 139 | 140 | def main(args): 141 | fix_random_seed_as(args.random_seed) 142 | path_data = '../datasets/data/' + args.dataset + '/dataset.pkl' 143 | with open(path_data, 'rb') as f: 144 | data_raw = pickle.load(f) 145 | 146 | # cold_hot_long_short(data_raw, args.dataset) 147 | 148 | args = item_num_create(args, len(data_raw['smap'])) 149 | tra_data = Data_Train(data_raw['train'], args) 150 | val_data = Data_Val(data_raw['train'], data_raw['val'], args) 151 | test_data = Data_Test(data_raw['train'], data_raw['val'], data_raw['test'], args) 152 | tra_data_loader = tra_data.get_pytorch_dataloaders() 153 | val_data_loader = val_data.get_pytorch_dataloaders() 154 | test_data_loader = test_data.get_pytorch_dataloaders() 155 | diffu_rec = create_model_diffu(args) 156 | rec_diffu_joint_model = Att_Diffuse_model(diffu_rec, args) 157 | 158 | best_model, test_results = model_train(tra_data_loader, val_data_loader, test_data_loader, rec_diffu_joint_model, args, logger) 159 | 160 | 161 | if args.long_head: 162 | cold_hot_dict, len_seq_dict, split_hotcold, split_length, list_len, list_num = cold_hot_long_short(data_raw, args.dataset) 163 | cold_data = Data_CHLS(cold_hot_dict['cold'], args) 164 | cold_data_loader = cold_data.get_pytorch_dataloaders() 165 | print('--------------Cold item-----------------------') 166 | LSHT_inference(best_model, args, cold_data_loader) 167 | 168 | hot_data = Data_CHLS(cold_hot_dict['hot'], args) 169 | hot_data_loader = hot_data.get_pytorch_dataloaders() 170 | print('--------------hot item-----------------------') 171 | 
LSHT_inference(best_model, args, hot_data_loader) 172 | 173 | short_data = Data_CHLS(len_seq_dict['short'], args) 174 | short_data_loader = short_data.get_pytorch_dataloaders() 175 | print('--------------Short-----------------------') 176 | LSHT_inference(best_model, args, short_data_loader) 177 | 178 | mid_short_data = Data_CHLS(len_seq_dict['mid_short'], args) 179 | mid_short_data_loader = mid_short_data.get_pytorch_dataloaders() 180 | print('--------------Mid_short-----------------------') 181 | LSHT_inference(best_model, args, mid_short_data_loader) 182 | 183 | mid_data = Data_CHLS(len_seq_dict['mid'], args) 184 | mid_data_loader = mid_data.get_pytorch_dataloaders() 185 | print('--------------Mid-----------------------') 186 | LSHT_inference(best_model, args, mid_data_loader) 187 | 188 | mid_long_data = Data_CHLS(len_seq_dict['mid_long'], args) 189 | mid_long_data_loader = mid_long_data.get_pytorch_dataloaders() 190 | print('--------------Mid_long-----------------------') 191 | LSHT_inference(best_model, args, mid_long_data_loader) 192 | 193 | long_data = Data_CHLS(len_seq_dict['long'], args) 194 | long_data_loader = long_data.get_pytorch_dataloaders() 195 | print('--------------Long-----------------------') 196 | LSHT_inference(best_model, args, long_data_loader) 197 | 198 | 199 | if __name__ == '__main__': 200 | main(args) 201 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import math 4 | from diffurec import DiffuRec 5 | import torch.nn.functional as F 6 | import copy 7 | import numpy as np 8 | from step_sample import LossAwareSampler 9 | import torch as th 10 | 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, hidden_size, eps=1e-12): 14 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
15 | """ 16 | super(LayerNorm, self).__init__() 17 | self.weight = nn.Parameter(torch.ones(hidden_size)) 18 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 19 | self.variance_epsilon = eps 20 | 21 | def forward(self, x): 22 | u = x.mean(-1, keepdim=True) 23 | s = (x - u).pow(2).mean(-1, keepdim=True) 24 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 25 | return self.weight * x + self.bias 26 | 27 | 28 | class Att_Diffuse_model(nn.Module): 29 | def __init__(self, diffu, args): 30 | super(Att_Diffuse_model, self).__init__() 31 | self.emb_dim = args.hidden_size 32 | self.item_num = args.item_num+1 33 | self.item_embeddings = nn.Embedding(self.item_num, self.emb_dim) 34 | self.embed_dropout = nn.Dropout(args.emb_dropout) 35 | self.position_embeddings = nn.Embedding(args.max_len, args.hidden_size) 36 | self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) 37 | self.dropout = nn.Dropout(args.dropout) 38 | self.diffu = diffu 39 | self.loss_ce = nn.CrossEntropyLoss() 40 | self.loss_ce_rec = nn.CrossEntropyLoss(reduction='none') 41 | self.loss_mse = nn.MSELoss() 42 | 43 | def diffu_pre(self, item_rep, tag_emb, mask_seq): 44 | seq_rep_diffu, item_rep_out, weights, t = self.diffu(item_rep, tag_emb, mask_seq) 45 | return seq_rep_diffu, item_rep_out, weights, t 46 | 47 | def reverse(self, item_rep, noise_x_t, mask_seq): 48 | reverse_pre = self.diffu.reverse_p_sample(item_rep, noise_x_t, mask_seq) 49 | return reverse_pre 50 | 51 | def loss_rec(self, scores, labels): 52 | return self.loss_ce(scores, labels.squeeze(-1)) 53 | 54 | def loss_diffu(self, rep_diffu, labels): 55 | scores = torch.matmul(rep_diffu, self.item_embeddings.weight.t()) 56 | scores_pos = scores.gather(1 , labels) ## labels: b x 1 57 | scores_neg_mean = (torch.sum(scores, dim=-1).unsqueeze(-1)-scores_pos)/(scores.shape[1]-1) 58 | 59 | loss = torch.min(-torch.log(torch.mean(torch.sigmoid((scores_pos - scores_neg_mean).squeeze(-1)))), torch.tensor(1e8)) 60 | 61 | # if 
isinstance(self.diffu.schedule_sampler, LossAwareSampler): 62 | # self.diffu.schedule_sampler.update_with_all_losses(t, loss.detach()) 63 | # loss = (loss * weights).mean() 64 | return loss 65 | 66 | def loss_diffu_ce(self, rep_diffu, labels): 67 | scores = torch.matmul(rep_diffu, self.item_embeddings.weight.t()) 68 | """ 69 | ### norm scores 70 | item_emb_norm = F.normalize(self.item_embeddings.weight, dim=-1) 71 | rep_diffu_norm = F.normalize(rep_diffu, dim=-1) 72 | temperature = 0.07 73 | scores = torch.matmul(rep_diffu_norm, item_emb_norm.t())/temperature 74 | """ 75 | return self.loss_ce(scores, labels.squeeze(-1)) 76 | 77 | def diffu_rep_pre(self, rep_diffu): 78 | scores = torch.matmul(rep_diffu, self.item_embeddings.weight.t()) 79 | return scores 80 | 81 | def loss_rmse(self, rep_diffu, labels): 82 | rep_gt = self.item_embeddings(labels).squeeze(1) 83 | return torch.sqrt(self.loss_mse(rep_gt, rep_diffu)) 84 | 85 | def routing_rep_pre(self, rep_diffu): 86 | item_norm = (self.item_embeddings.weight**2).sum(-1).view(-1, 1) ## N x 1 87 | rep_norm = (rep_diffu**2).sum(-1).view(-1, 1) ## B x 1 88 | sim = torch.matmul(rep_diffu, self.item_embeddings.weight.t()) ## B x N 89 | dist = rep_norm + item_norm.transpose(0, 1) - 2.0 * sim 90 | dist = torch.clamp(dist, 0.0, np.inf) 91 | 92 | return -dist 93 | 94 | def regularization_rep(self, seq_rep, mask_seq): 95 | seqs_norm = seq_rep/seq_rep.norm(dim=-1)[:, :, None] 96 | seqs_norm = seqs_norm * mask_seq.unsqueeze(-1) 97 | cos_mat = torch.matmul(seqs_norm, seqs_norm.transpose(1, 2)) 98 | cos_sim = torch.mean(torch.mean(torch.sum(torch.sigmoid(-cos_mat), dim=-1), dim=-1), dim=-1) ## not real mean 99 | return cos_sim 100 | 101 | def regularization_seq_item_rep(self, seq_rep, item_rep, mask_seq): 102 | item_norm = item_rep/item_rep.norm(dim=-1)[:, :, None] 103 | item_norm = item_norm * mask_seq.unsqueeze(-1) 104 | 105 | seq_rep_norm = seq_rep/seq_rep.norm(dim=-1)[:, None] 106 | sim_mat = torch.sigmoid(-torch.matmul(item_norm, 
seq_rep_norm.unsqueeze(-1)).squeeze(-1)) 107 | return torch.mean(torch.sum(sim_mat, dim=-1)/torch.sum(mask_seq, dim=-1)) 108 | 109 | def forward(self, sequence, tag, train_flag=True): 110 | seq_length = sequence.size(1) 111 | # position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device) 112 | # position_ids = position_ids.unsqueeze(0).expand_as(sequence) 113 | # position_embeddings = self.position_embeddings(position_ids) 114 | 115 | item_embeddings = self.item_embeddings(sequence) 116 | item_embeddings = self.embed_dropout(item_embeddings) ## dropout first than layernorm 117 | 118 | # item_embeddings = item_embeddings + position_embeddings 119 | 120 | item_embeddings = self.LayerNorm(item_embeddings) 121 | 122 | mask_seq = (sequence>0).float() 123 | 124 | if train_flag: 125 | tag_emb = self.item_embeddings(tag.squeeze(-1)) ## B x H 126 | rep_diffu, rep_item, weights, t = self.diffu_pre(item_embeddings, tag_emb, mask_seq) 127 | 128 | # item_rep_dis = self.regularization_rep(rep_item, mask_seq) 129 | # seq_rep_dis = self.regularization_seq_item_rep(rep_diffu, rep_item, mask_seq) 130 | 131 | item_rep_dis = None 132 | seq_rep_dis = None 133 | else: 134 | # noise_x_t = th.randn_like(tag_emb) 135 | noise_x_t = th.randn_like(item_embeddings[:,-1,:]) 136 | rep_diffu = self.reverse(item_embeddings, noise_x_t, mask_seq) 137 | weights, t, item_rep_dis, seq_rep_dis = None, None, None, None 138 | 139 | # item_rep = self.model_main(item_embeddings, rep_diffu, mask_seq) 140 | # seq_rep = item_rep[:, -1, :] 141 | # scores = torch.matmul(seq_rep, self.item_embeddings.weight.t()) 142 | scores = None 143 | return scores, rep_diffu, weights, t, item_rep_dis, seq_rep_dis 144 | 145 | 146 | def create_model_diffu(args): 147 | diffu_pre = DiffuRec(args) 148 | return diffu_pre 149 | -------------------------------------------------------------------------------- /src/step_sample.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np 2 | from abc import ABC, abstractmethod 3 | import torch as th 4 | import torch.distributed as dist 5 | 6 | 7 | class ScheduleSampler(ABC): 8 | """ 9 | A distribution over timesteps in the diffusion process, intended to reduce 10 | variance of the objective. 11 | 12 | By default, samplers perform unbiased importance sampling, in which the 13 | objective's mean is unchanged. 14 | However, subclasses may override sample() to change how the resampled 15 | terms are reweighted, allowing for actual changes in the objective. 16 | """ 17 | 18 | @abstractmethod 19 | def weights(self): 20 | """ 21 | Get a numpy array of weights, one per diffusion step. 22 | The weights needn't be normalized, but must be positive. 23 | """ 24 | 25 | def sample(self, batch_size, device): 26 | """ 27 | Importance-sample timesteps for a batch. 28 | 29 | :param batch_size: the number of timesteps. 30 | :param device: the torch device to save to. 31 | :return: a tuple (timesteps, weights): 32 | - timesteps: a tensor of timestep indices. 33 | - weights: a tensor of weights to scale the resulting losses. 34 | """ 35 | w = self.weights() 36 | p = w / np.sum(w) 37 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 38 | indices = th.from_numpy(indices_np).long().to(device) 39 | weights_np = 1 / (len(p) * p[indices_np]) 40 | weights = th.from_numpy(weights_np).float().to(device) 41 | return indices, weights 42 | 43 | 44 | class UniformSampler(ScheduleSampler): 45 | def __init__(self, num_timesteps): 46 | self.num_timesteps = num_timesteps 47 | self._weights = np.ones([self.num_timesteps]) 48 | 49 | def weights(self): 50 | return self._weights 51 | 52 | 53 | class LossAwareSampler(ScheduleSampler): 54 | def update_with_local_losses(self, local_ts, local_losses): 55 | """ 56 | Update the reweighting using losses from a model. 57 | 58 | Call this method from each rank with a batch of timesteps and the 59 | corresponding losses for each of those timesteps. 
60 | This method will perform synchronization to make sure all of the ranks 61 | maintain the exact same reweighting. 62 | 63 | :param local_ts: an integer Tensor of timesteps. 64 | :param local_losses: a 1D Tensor of losses. 65 | """ 66 | batch_sizes = [ 67 | th.tensor([0], dtype=th.int32, device=local_ts.device) 68 | for _ in range(dist.get_world_size()) 69 | ] 70 | dist.all_gather( 71 | batch_sizes, 72 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 73 | ) 74 | 75 | # Pad all_gather batches to be the maximum batch size. 76 | batch_sizes = [x.item() for x in batch_sizes] 77 | max_bs = max(batch_sizes) 78 | 79 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 80 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 81 | dist.all_gather(timestep_batches, local_ts) 82 | dist.all_gather(loss_batches, local_losses) 83 | timesteps = [ 84 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 85 | ] 86 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 87 | self.update_with_all_losses(timesteps, losses) 88 | 89 | @abstractmethod 90 | def update_with_all_losses(self, ts, losses): 91 | """ 92 | Update the reweighting using losses from a model. 93 | 94 | Sub-classes should override this method to update the reweighting 95 | using losses from the model. 96 | 97 | This method directly updates the reweighting without synchronizing 98 | between workers. It is called by update_with_local_losses from all 99 | ranks with identical arguments. Thus, it should have deterministic 100 | behavior to maintain state across workers. 101 | 102 | :param ts: a list of int timesteps. 103 | :param losses: a list of float losses, one per timestep. 
104 | """ 105 | 106 | 107 | class LossSecondMomentResampler(LossAwareSampler): 108 | def __init__(self, num_timesteps, history_per_term=10, uniform_prob=0.001): 109 | self.num_timesteps = num_timesteps 110 | self.history_per_term = history_per_term 111 | self.uniform_prob = uniform_prob 112 | self._loss_history = np.zeros( 113 | [self.num_timesteps, history_per_term], dtype=np.float64 114 | ) 115 | self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int64) ## np.int was removed in NumPy 1.24 116 | 117 | def weights(self): 118 | if not self._warmed_up(): 119 | return np.ones([self.num_timesteps], dtype=np.float64) 120 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 121 | weights /= np.sum(weights) 122 | weights *= 1 - self.uniform_prob 123 | weights += self.uniform_prob / len(weights) 124 | return weights 125 | 126 | def update_with_all_losses(self, ts, losses): 127 | for t, loss in zip(ts, losses): 128 | if self._loss_counts[t] == self.history_per_term: 129 | # Shift out the oldest loss term. 130 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 131 | self._loss_history[t, -1] = loss 132 | else: 133 | self._loss_history[t, self._loss_counts[t]] = loss 134 | self._loss_counts[t] += 1 135 | 136 | def _warmed_up(self): 137 | return (self._loss_counts == self.history_per_term).all() 138 | 139 | 140 | class FixSampler(ScheduleSampler): 141 | def __init__(self, num_timesteps): 142 | self.num_timesteps = num_timesteps 143 | ############################################################### 144 | ### You can customize the per-step sampling weights here.   ### 145 | ############################################################### 146 | self._weights = np.concatenate([np.ones([num_timesteps//2]), np.zeros([num_timesteps//2]) + 0.5]) 147 | 148 | def weights(self): 149 | return self._weights 150 | 151 | 152 | def create_named_schedule_sampler(name, num_timesteps): 153 | """ 154 | Create a ScheduleSampler from a library of pre-defined samplers. 155 | :param name: the name of the sampler. 
156 | :param num_timesteps: the number of diffusion timesteps to sample from. 157 | """ 158 | if name == "uniform": 159 | return UniformSampler(num_timesteps) 160 | elif name == "lossaware": 161 | return LossSecondMomentResampler(num_timesteps) ## default setting 162 | elif name == "fixstep": 163 | return FixSampler(num_timesteps) 164 | else: 165 | raise NotImplementedError(f"unknown schedule sampler: {name}") 166 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.optim as optim 3 | import datetime 4 | import torch 5 | import numpy as np 6 | import copy 7 | import time 8 | import pickle 9 | 10 | 11 | def optimizers(model, args): 12 | if args.optimizer.lower() == 'adam': 13 | return optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 14 | elif args.optimizer.lower() == 'sgd': 15 | return optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum or 0.0) ## --momentum defaults to None, which SGD rejects 16 | else: 17 | raise ValueError(f"unknown optimizer: {args.optimizer}") 18 | 19 | 20 | def cal_hr(label, predict, ks): 21 | max_ks = max(ks) 22 | _, topk_predict = torch.topk(predict, k=max_ks, dim=-1) 23 | hit = label == topk_predict 24 | hr = [hit[:, :ks[i]].sum().item()/label.size()[0] for i in range(len(ks))] 25 | return hr 26 | 27 | 28 | def cal_ndcg(label, predict, ks): 29 | max_ks = max(ks) 30 | _, topk_predict = torch.topk(predict, k=max_ks, dim=-1) 31 | hit = (label == topk_predict).int() 32 | ndcg = [] 33 | for k in ks: 34 | max_dcg = dcg(torch.tensor([1] + [0] * (k-1))) 35 | predict_dcg = dcg(hit[:, :k]) 36 | ndcg.append((predict_dcg/max_dcg).mean().item()) 37 | return ndcg 38 | 39 | 40 | def dcg(hit): 41 | log2 = torch.log2(torch.arange(1, hit.size()[-1] + 1) + 1).unsqueeze(0) 42 | rel = (hit/log2).sum(dim=-1) 43 | return rel 44 | 45 | 46 | def hrs_and_ndcgs_k(scores, labels, ks): 47 | metrics = {} 48 | ndcg = 
cal_ndcg(labels.clone().detach().to('cpu'), scores.clone().detach().to('cpu'), ks) 49 | hr = cal_hr(labels.clone().detach().to('cpu'), scores.clone().detach().to('cpu'), ks) 50 | for k, ndcg_temp, hr_temp in zip(ks, ndcg, hr): 51 | metrics['HR@%d' % k] = hr_temp 52 | metrics['NDCG@%d' % k] = ndcg_temp 53 | return metrics 54 | 55 | 56 | def LSHT_inference(model_joint, args, data_loader): 57 | device = args.device 58 | model_joint = model_joint.to(device) 59 | with torch.no_grad(): 60 | test_metrics_dict = {'HR@5': [], 'NDCG@5': [], 'HR@10': [], 'NDCG@10': [], 'HR@20': [], 'NDCG@20': []} 61 | test_metrics_dict_mean = {} 62 | for test_batch in data_loader: 63 | test_batch = [x.to(device) for x in test_batch] 64 | 65 | scores_rec, rep_diffu, _, _, _, _ = model_joint(test_batch[0], test_batch[1], train_flag=False) 66 | scores_rec_diffu = model_joint.diffu_rep_pre(rep_diffu) 67 | metrics = hrs_and_ndcgs_k(scores_rec_diffu, test_batch[1], [5, 10, 20]) 68 | for k, v in metrics.items(): 69 | test_metrics_dict[k].append(v) 70 | for key_temp, values_temp in test_metrics_dict.items(): 71 | values_mean = round(np.mean(values_temp) * 100, 4) 72 | test_metrics_dict_mean[key_temp] = values_mean 73 | print(test_metrics_dict_mean) 74 | 75 | 76 | def model_train(tra_data_loader, val_data_loader, test_data_loader, model_joint, args, logger): 77 | epochs = args.epochs 78 | device = args.device 79 | metric_ks = args.metric_ks 80 | model_joint = model_joint.to(device) 81 | is_parallel = args.num_gpu > 1 82 | if is_parallel: 83 | model_joint = nn.DataParallel(model_joint) 84 | optimizer = optimizers(model_joint, args) 85 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.decay_step, gamma=args.gamma) 86 | best_metrics_dict = {'Best_HR@5': 0, 'Best_NDCG@5': 0, 'Best_HR@10': 0, 'Best_NDCG@10': 0, 'Best_HR@20': 0, 'Best_NDCG@20': 0} 87 | best_epoch = {'Best_epoch_HR@5': 0, 'Best_epoch_NDCG@5': 0, 'Best_epoch_HR@10': 0, 'Best_epoch_NDCG@10': 0, 'Best_epoch_HR@20': 0, 
'Best_epoch_NDCG@20': 0} 88 | bad_count = 0 89 | 90 | for epoch_temp in range(epochs): 91 | print('Epoch: {}'.format(epoch_temp)) 92 | logger.info('Epoch: {}'.format(epoch_temp)) 93 | model_joint.train() 94 | 95 | flag_update = 0 96 | for index_temp, train_batch in enumerate(tra_data_loader): 97 | train_batch = [x.to(device) for x in train_batch] 98 | optimizer.zero_grad() 99 | scores, diffu_rep, weights, t, item_rep_dis, seq_rep_dis = model_joint(train_batch[0], train_batch[1], train_flag=True) 100 | loss_diffu_value = model_joint.loss_diffu_ce(diffu_rep, train_batch[1]) ## use this not above 101 | 102 | loss_all = loss_diffu_value 103 | loss_all.backward() 104 | 105 | optimizer.step() 106 | if index_temp % int(len(tra_data_loader) / 5 + 1) == 0: 107 | print('[%d/%d] Loss: %.4f' % (index_temp, len(tra_data_loader), loss_all.item())) 108 | logger.info('[%d/%d] Loss: %.4f' % (index_temp, len(tra_data_loader), loss_all.item())) 109 | print("loss in epoch {}: {}".format(epoch_temp, loss_all.item())) 110 | lr_scheduler.step() 111 | 112 | if epoch_temp != 0 and epoch_temp % args.eval_interval == 0: 113 | print('start predicting: ', datetime.datetime.now()) 114 | logger.info('start predicting: {}'.format(datetime.datetime.now())) 115 | model_joint.eval() 116 | with torch.no_grad(): 117 | metrics_dict = {'HR@5': [], 'NDCG@5': [], 'HR@10': [], 'NDCG@10': [], 'HR@20': [], 'NDCG@20': []} 118 | # metrics_dict_mean = {} 119 | for val_batch in val_data_loader: 120 | val_batch = [x.to(device) for x in val_batch] 121 | scores_rec, rep_diffu, _, _, _, _ = model_joint(val_batch[0], val_batch[1], train_flag=False) 122 | scores_rec_diffu = model_joint.diffu_rep_pre(rep_diffu) ### inner_production 123 | # scores_rec_diffu = model_joint.routing_rep_pre(rep_diffu) ### routing_rep_pre 124 | metrics = hrs_and_ndcgs_k(scores_rec_diffu, val_batch[1], metric_ks) 125 | for k, v in metrics.items(): 126 | metrics_dict[k].append(v) 127 | 128 | for key_temp, values_temp in metrics_dict.items(): 
129 | values_mean = round(np.mean(values_temp) * 100, 4) 130 | if values_mean > best_metrics_dict['Best_' + key_temp]: 131 | flag_update = 1 132 | bad_count = 0 133 | best_metrics_dict['Best_' + key_temp] = values_mean 134 | best_epoch['Best_epoch_' + key_temp] = epoch_temp 135 | 136 | if flag_update == 0: 137 | bad_count += 1 138 | else: 139 | print(best_metrics_dict) 140 | print(best_epoch) 141 | logger.info(best_metrics_dict) 142 | logger.info(best_epoch) 143 | best_model = copy.deepcopy(model_joint) 144 | if bad_count >= args.patience: 145 | break 146 | 147 | 148 | logger.info(best_metrics_dict) 149 | logger.info(best_epoch) 150 | 151 | if args.eval_interval > epochs: 152 | best_model = copy.deepcopy(model_joint) 153 | 154 | 155 | top_100_item = [] 156 | with torch.no_grad(): 157 | test_metrics_dict = {'HR@5': [], 'NDCG@5': [], 'HR@10': [], 'NDCG@10': [], 'HR@20': [], 'NDCG@20': []} 158 | test_metrics_dict_mean = {} 159 | for test_batch in test_data_loader: 160 | test_batch = [x.to(device) for x in test_batch] 161 | scores_rec, rep_diffu, _, _, _, _ = best_model(test_batch[0], test_batch[1], train_flag=False) 162 | scores_rec_diffu = best_model.diffu_rep_pre(rep_diffu) ### Inner Production 163 | # scores_rec_diffu = best_model.routing_rep_pre(rep_diffu) ### routing 164 | 165 | _, indices = torch.topk(scores_rec_diffu, k=100) 166 | top_100_item.append(indices) 167 | 168 | metrics = hrs_and_ndcgs_k(scores_rec_diffu, test_batch[1], metric_ks) 169 | for k, v in metrics.items(): 170 | test_metrics_dict[k].append(v) 171 | 172 | for key_temp, values_temp in test_metrics_dict.items(): 173 | values_mean = round(np.mean(values_temp) * 100, 4) 174 | test_metrics_dict_mean[key_temp] = values_mean 175 | print('Test------------------------------------------------------') 176 | logger.info('Test------------------------------------------------------') 177 | print(test_metrics_dict_mean) 178 | logger.info(test_metrics_dict_mean) 179 | print('Best 
Eval---------------------------------------------------------') 180 | logger.info('Best Eval---------------------------------------------------------') 181 | print(best_metrics_dict) 182 | print(best_epoch) 183 | logger.info(best_metrics_dict) 184 | logger.info(best_epoch) 185 | 186 | print(args) 187 | 188 | if args.diversity_measure: 189 | path_data = '../datasets/data/category/' + args.dataset +'/id_category_dict.pkl' 190 | with open(path_data, 'rb') as f: 191 | id_category_dict = pickle.load(f) 192 | id_top_100 = torch.cat(top_100_item, dim=0).tolist() 193 | category_list_100 = [] 194 | for id_top_100_temp in id_top_100: 195 | category_temp_list = [] 196 | for id_temp in id_top_100_temp: 197 | category_temp_list.append(id_category_dict[id_temp]) 198 | category_list_100.append(category_temp_list) 199 | # category_list_100.append(category_list_100) ## disabled: this appended the list to itself 200 | path_data_category = '../datasets/data/category/' + args.dataset +'/DiffuRec_top100_category.pkl' 201 | with open(path_data_category, 'wb') as f: 202 | pickle.dump(category_list_100, f) 203 | 204 | 205 | return best_model, test_metrics_dict_mean 206 | 207 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data_utils 2 | import torch 3 | 4 | 5 | class TrainDataset(data_utils.Dataset): 6 | def __init__(self, id2seq, max_len): 7 | self.id2seq = id2seq 8 | self.max_len = max_len 9 | 10 | def __len__(self): 11 | return len(self.id2seq) 12 | 13 | def __getitem__(self, index): 14 | seq = self._getseq(index) 15 | labels = [seq[-1]] 16 | tokens = seq[:-1] 17 | tokens = tokens[-self.max_len:] 18 | mask_len = self.max_len - len(tokens) 19 | tokens = [0] * mask_len + tokens 20 | return torch.LongTensor(tokens), torch.LongTensor(labels) 21 | 22 | def _getseq(self, idx): 23 | return self.id2seq[idx] 24 | 25 | 26 | class Data_Train(): 27 | def __init__(self, data_train, 
args): 28 | self.u2seq = data_train 29 | self.max_len = args.max_len 30 | self.batch_size = args.batch_size 31 | self.split_onebyone() 32 | 33 | def split_onebyone(self): 34 | self.id_seq = {} 35 | self.id_seq_user = {} 36 | idx = 0 37 | for user_temp, seq_temp in self.u2seq.items(): 38 | for star in range(len(seq_temp)-1): 39 | self.id_seq[idx] = seq_temp[:star+2] 40 | self.id_seq_user[idx] = user_temp 41 | idx += 1 42 | 43 | def get_pytorch_dataloaders(self): 44 | dataset = TrainDataset(self.id_seq, self.max_len) 45 | return data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True) 46 | 47 | 48 | class ValDataset(data_utils.Dataset): 49 | def __init__(self, u2seq, u2answer, max_len): 50 | self.u2seq = u2seq 51 | self.users = sorted(self.u2seq.keys()) 52 | self.u2answer = u2answer 53 | self.max_len = max_len 54 | 55 | def __len__(self): 56 | return len(self.users) 57 | 58 | def __getitem__(self, index): 59 | user = self.users[index] 60 | seq = self.u2seq[user] 61 | answer = self.u2answer[user] 62 | seq = seq[-self.max_len:] 63 | padding_len = self.max_len - len(seq) 64 | seq = [0] * padding_len + seq 65 | return torch.LongTensor(seq), torch.LongTensor(answer) 66 | 67 | 68 | class Data_Val(): 69 | def __init__(self, data_train, data_val, args): 70 | self.batch_size = args.batch_size 71 | self.u2seq = data_train 72 | self.u2answer = data_val 73 | self.max_len = args.max_len 74 | 75 | 76 | def get_pytorch_dataloaders(self): 77 | dataset = ValDataset(self.u2seq, self.u2answer, self.max_len) 78 | dataloader = data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True) 79 | return dataloader 80 | 81 | 82 | class TestDataset(data_utils.Dataset): 83 | def __init__(self, u2seq, u2_seq_add, u2answer, max_len): 84 | self.u2seq = u2seq 85 | self.u2seq_add = u2_seq_add 86 | self.users = sorted(self.u2seq.keys()) 87 | self.u2answer = u2answer 88 | self.max_len = max_len 89 | 90 | def __len__(self): 91 | return 
len(self.users) 92 | 93 | def __getitem__(self, index): 94 | user = self.users[index] 95 | seq = self.u2seq[user] + self.u2seq_add[user] 96 | # seq = self.u2seq[user] 97 | answer = self.u2answer[user] 98 | seq = seq[-self.max_len:] 99 | padding_len = self.max_len - len(seq) 100 | seq = [0] * padding_len + seq 101 | return torch.LongTensor(seq), torch.LongTensor(answer) 102 | 103 | 104 | class Data_Test(): 105 | def __init__(self, data_train, data_val, data_test, args): 106 | self.batch_size = args.batch_size 107 | self.u2seq = data_train 108 | self.u2seq_add = data_val 109 | self.u2answer = data_test 110 | self.max_len = args.max_len 111 | 112 | def get_pytorch_dataloaders(self): 113 | dataset = TestDataset(self.u2seq, self.u2seq_add, self.u2answer, self.max_len) 114 | dataloader = data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True) 115 | return dataloader 116 | 117 | 118 | class CHLSDataset(data_utils.Dataset): 119 | def __init__(self, data, max_len): 120 | self.data = data 121 | self.max_len = max_len 122 | 123 | def __len__(self): 124 | return len(self.data) 125 | 126 | def __getitem__(self, index): 127 | 128 | data_temp = self.data[index] 129 | seq = data_temp[:-1] 130 | answer = [data_temp[-1]] 131 | seq = seq[-self.max_len:] 132 | padding_len = self.max_len - len(seq) 133 | seq = [0] * padding_len + seq 134 | return torch.LongTensor(seq), torch.LongTensor(answer) 135 | 136 | 137 | class Data_CHLS(): 138 | def __init__(self, data, args): 139 | self.batch_size = args.batch_size 140 | self.max_len = args.max_len 141 | self.data = data 142 | 143 | def get_pytorch_dataloaders(self): 144 | dataset = CHLSDataset(self.data, self.max_len) 145 | dataloader = data_utils.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, pin_memory=True) 146 | return dataloader 147 | --------------------------------------------------------------------------------
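
Note on the data pipeline in src/utils.py: every dataset class keeps only the most recent `max_len` items of a sequence and left-pads it with item id 0, and `Data_Train.split_onebyone` expands each user sequence into one training example per prefix of length two or more. The sketch below reproduces those two steps in plain Python, without PyTorch, so the behavior can be checked in isolation. The helper names `left_pad` and `expand_prefixes` are hypothetical and are not part of the repository; the `max_len` guard is an added assumption, since the original code does not validate it.

```python
def left_pad(seq, max_len, pad_id=0):
    """Keep the last max_len items and left-pad with pad_id (0 in utils.py)."""
    if max_len <= 0:
        # Assumption: the repository does not check this; seq[-0:] would keep everything.
        raise ValueError("max_len must be positive")
    recent = list(seq)[-max_len:]
    return [pad_id] * (max_len - len(recent)) + recent


def expand_prefixes(seq, max_len):
    """One (tokens, label) pair per prefix of length >= 2, as in split_onebyone + TrainDataset."""
    examples = []
    for end in range(2, len(seq) + 1):
        prefix = seq[:end]
        # The last item of the prefix is the label; the rest is the padded input.
        examples.append((left_pad(prefix[:-1], max_len), prefix[-1]))
    return examples


if __name__ == "__main__":
    for tokens, label in expand_prefixes([11, 12, 13, 14], max_len=3):
        print(tokens, "->", label)
```

Because padding uses id 0, `Att_Diffuse_model.forward` can recover the mask with `(sequence > 0)`, which is why real item ids are expected to start at 1.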