├── NOTICE
├── requirements.txt
├── CODE_OF_CONDUCT.md
├── SimRec
│   ├── scripts
│   │   ├── ML-1M
│   │   │   └── train.sh
│   │   ├── Steam
│   │   │   └── train.sh
│   │   ├── Tools
│   │   │   └── train.sh
│   │   ├── Beauty
│   │   │   └── train.sh
│   │   ├── PetSupplies
│   │   │   └── train.sh
│   │   └── HomeKitchen
│   │       └── train.sh
│   ├── model.py
│   ├── utils.py
│   └── main.py
├── README.md
├── CONTRIBUTING.md
├── data_preprocessing
│   ├── ML-1M
│   │   └── preprocessing_data.ipynb
│   ├── Steam
│   │   └── preprocessing_data.ipynb
│   ├── Beauty
│   │   └── preprocessing_data.ipynb
│   ├── PetSupplies
│   │   └── preprocessing_data.ipynb
│   ├── HomeKitchen
│   │   └── preprocessing_data.ipynb
│   ├── Tools
│   │   └── preprocessing_data.ipynb
│   └── calculate_similarity_scores.ipynb
└── LICENSE

/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | tqdm
4 | transformers
5 | sentence-transformers
6 | pandas
7 | wget
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 | 
--------------------------------------------------------------------------------
/SimRec/scripts/ML-1M/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | NOW=`date +'%I_%M_%d_%m'`
3 | 
4 | EMBEDDING_MODEL=thenlper_gte-large
5 | DATASET_PARTIAL_PATH="../data_preprocessing/ML-1M/ml-1m"
6 | DATASET="${DATASET_PARTIAL_PATH}.txt"
7 | ITEM_FREQ="${DATASET_PARTIAL_PATH}-train_item_freq.txt"
8 | SIMILARITY_INDICES="${DATASET_PARTIAL_PATH}-similarity-indices-${EMBEDDING_MODEL}.pt"
9 | SIMILARITY_VALUES="${DATASET_PARTIAL_PATH}-similarity-values-${EMBEDDING_MODEL}.pt"
10 | 
11 | SIMILARITY_THREHOLD=0.5
12 | TEMPERATURE=1
13 | LAMBDA=0.6
14 | LAMBDA_SCHEDULING=LINEAR
15 | LAMBDA_WARMPUP=1000
16 | LAMBDA_STEPS=7000
17 | MAX_LEN=200
18 | BATCH_SIZE=128
19 | LR=0.001
20 | DROPOUT=0.5
21 | NUM_BLOCKS=3
22 | EPOCHS=200
23 | DEVICE="cuda:0"
24 | HIDDEN_DIM=100
25 | TRAIN_DIR="results/ML-1M/${NOW}"
26 | 
27 | python main.py --dataset ${DATASET}\
28 | --item_frequency ${ITEM_FREQ}\
29 | --similarity_indices ${SIMILARITY_INDICES}\
30 | --similarity_values ${SIMILARITY_VALUES}\
31 | --similarity_threshold ${SIMILARITY_THREHOLD}\
32 | --temperature ${TEMPERATURE}\
33 | --lambd ${LAMBDA}\
34 | --lambd_scheduling "${LAMBDA_SCHEDULING}"\
35 | --lambd_warmup_steps ${LAMBDA_WARMPUP}\
36 | --lambd_steps ${LAMBDA_STEPS}\
37 | --batch_size ${BATCH_SIZE}\
38 | --lr ${LR}\
39 | --maxlen ${MAX_LEN}\
40 | --dropout_rate ${DROPOUT}\
41 | --num_blocks ${NUM_BLOCKS}\
42 | --num_epochs ${EPOCHS}\
43 | --hidden_units ${HIDDEN_DIM}\
44 | --train_dir ${TRAIN_DIR}\
45 | --device ${DEVICE}
--------------------------------------------------------------------------------
/SimRec/scripts/Steam/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/zsh
2 | NOW=`date +'%I_%M_%d_%m'`
3 | EMBEDDING_MODEL=thenlper_gte-large
4 | 
DATASET_PARTIAL_PATH="../data_preprocessing/Steam/steam" 5 | DATASET="${DATASET_PARTIAL_PATH}.txt" 6 | ITEM_FREQ="${DATASET_PARTIAL_PATH}-train_item_freq.txt" 7 | SIMILARITY_INDICES="${DATASET_PARTIAL_PATH}-similarity-indices-${EMBEDDING_MODEL}.pt" 8 | SIMILARITY_VALUES="${DATASET_PARTIAL_PATH}-similarity-values-${EMBEDDING_MODEL}.pt" 9 | 10 | 11 | SIMILARITY_THREHOLD=0.6 12 | TEMPERATURE=1.5 13 | LAMBDA=0.2 14 | LAMBDA_SCHEDULING=LINEAR 15 | LAMBDA_WARMPUP=10000 16 | LAMBDA_STEPS=50000 17 | MAX_LEN=50 18 | BATCH_SIZE=128 19 | LR=0.001 20 | DROPOUT=0.5 21 | NUM_BLOCKS=2 22 | EPOCHS=200 23 | DEVICE="cuda:0" 24 | HIDDEN_DIM=100 25 | TRAIN_DIR="results/steam/${NOW}" 26 | 27 | python main.py --dataset ${DATASET}\ 28 | --item_frequency ${ITEM_FREQ}\ 29 | --similarity_indices ${SIMILARITY_INDICES}\ 30 | --similarity_values ${SIMILARITY_VALUES}\ 31 | --similarity_threshold ${SIMILARITY_THREHOLD}\ 32 | --temperature ${TEMPERATURE}\ 33 | --lambd ${LAMBDA}\ 34 | --lambd_scheduling "${LAMBDA_SCHEDULING}"\ 35 | --lambd_warmup_steps ${LAMBDA_WARMPUP}\ 36 | --lambd_steps ${LAMBDA_STEPS}\ 37 | --batch_size ${BATCH_SIZE}\ 38 | --lr ${LR}\ 39 | --maxlen ${MAX_LEN}\ 40 | --dropout_rate ${DROPOUT}\ 41 | --num_blocks ${NUM_BLOCKS}\ 42 | --num_epochs ${EPOCHS}\ 43 | --hidden_units ${HIDDEN_DIM}\ 44 | --train_dir ${TRAIN_DIR}\ 45 | --device ${DEVICE} -------------------------------------------------------------------------------- /SimRec/scripts/Tools/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | NOW=`date +'%I_%M_%d_%m'` 3 | 4 | EMBEDDING_MODEL=thenlper_gte-large 5 | DATASET_PARTIAL_PATH="../data_preprocessing/Tools/Tools" 6 | DATASET="${DATASET_PARTIAL_PATH}.txt" 7 | ITEM_FREQ="${DATASET_PARTIAL_PATH}-train_item_freq.txt" 8 | SIMILARITY_INDICES="${DATASET_PARTIAL_PATH}-similarity-indices-${EMBEDDING_MODEL}.pt" 9 | SIMILARITY_VALUES="${DATASET_PARTIAL_PATH}-similarity-values-${EMBEDDING_MODEL}.pt" 10 | 11 | SIMILARITY_THREHOLD=0.4 12 | TEMPERATURE=0.5 13 | LAMBDA=0.8 14 | LAMBDA_SCHEDULING=LINEAR 15 | LAMBDA_WARMPUP=1000 16 | LAMBDA_STEPS=60000 17 | MAX_LEN=50 18 | BATCH_SIZE=128 19 | LR=0.0001 20 | DROPOUT=0.5 21 | NUM_BLOCKS=3 22 | EPOCHS=200 23 | DEVICE="cuda:0" 24 | HIDDEN_DIM=50 25 | TRAIN_DIR="results/tools/${NOW}" 26 | 27 | python main.py --dataset ${DATASET}\ 28 | --item_frequency ${ITEM_FREQ}\ 29 | --similarity_indices ${SIMILARITY_INDICES}\ 30 | --similarity_values ${SIMILARITY_VALUES}\ 31 | --similarity_threshold ${SIMILARITY_THREHOLD}\ 32 | --temperature ${TEMPERATURE}\ 33 | --lambd ${LAMBDA}\ 34 | --lambd_scheduling "${LAMBDA_SCHEDULING}"\ 35 | --lambd_warmup_steps ${LAMBDA_WARMPUP}\ 36 | --lambd_steps ${LAMBDA_STEPS}\ 37 | --batch_size ${BATCH_SIZE}\ 38 | --lr ${LR}\ 39 | --maxlen ${MAX_LEN}\ 40 | --dropout_rate ${DROPOUT}\ 41 | --num_blocks ${NUM_BLOCKS}\ 42 | --num_epochs ${EPOCHS}\ 43 | --hidden_units ${HIDDEN_DIM}\ 44 | --train_dir ${TRAIN_DIR}\ 45 | --device ${DEVICE} 46 | -------------------------------------------------------------------------------- /SimRec/scripts/Beauty/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | NOW=`date +'%I_%M_%d_%m'` 3 | 4 | EMBEDDING_MODEL=thenlper_gte-large 5 | DATASET_PARTIAL_PATH="../data_preprocessing/Beauty/Beauty" 6 | DATASET="${DATASET_PARTIAL_PATH}.txt" 7 | ITEM_FREQ="${DATASET_PARTIAL_PATH}-train_item_freq.txt" 8 | SIMILARITY_INDICES="${DATASET_PARTIAL_PATH}-similarity-indices-${EMBEDDING_MODEL}.pt" 9 | 
SIMILARITY_VALUES="${DATASET_PARTIAL_PATH}-similarity-values-${EMBEDDING_MODEL}.pt" 10 | 11 | SIMILARITY_THREHOLD=0.9 12 | TEMPERATURE=0.5 13 | LAMBDA=0.3 14 | LAMBDA_SCHEDULING=LINEAR 15 | LAMBDA_WARMPUP=1000 16 | LAMBDA_STEPS=81000 17 | MAX_LEN=50 18 | BATCH_SIZE=128 19 | LR=0.0001 20 | DROPOUT=0.5 21 | NUM_BLOCKS=3 22 | EPOCHS=210 23 | DEVICE="cuda:0" 24 | HIDDEN_DIM=100 25 | TRAIN_DIR="results/beauty/${NOW}" 26 | 27 | python main.py --dataset ${DATASET}\ 28 | --item_frequency ${ITEM_FREQ}\ 29 | --similarity_indices ${SIMILARITY_INDICES}\ 30 | --similarity_values ${SIMILARITY_VALUES}\ 31 | --similarity_threshold ${SIMILARITY_THREHOLD}\ 32 | --temperature ${TEMPERATURE}\ 33 | --lambd ${LAMBDA}\ 34 | --lambd_scheduling "${LAMBDA_SCHEDULING}"\ 35 | --lambd_warmup_steps ${LAMBDA_WARMPUP}\ 36 | --lambd_steps ${LAMBDA_STEPS}\ 37 | --batch_size ${BATCH_SIZE}\ 38 | --lr ${LR}\ 39 | --maxlen ${MAX_LEN}\ 40 | --dropout_rate ${DROPOUT}\ 41 | --num_blocks ${NUM_BLOCKS}\ 42 | --num_epochs ${EPOCHS}\ 43 | --hidden_units ${HIDDEN_DIM}\ 44 | --train_dir ${TRAIN_DIR}\ 45 | --device ${DEVICE} 46 | -------------------------------------------------------------------------------- /SimRec/scripts/PetSupplies/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | NOW=`date +'%I_%M_%d_%m'` 3 | 4 | EMBEDDING_MODEL=thenlper_gte-large 5 | DATASET_PARTIAL_PATH="../data_preprocessing/PetSupplies/Pet" 6 | DATASET="${DATASET_PARTIAL_PATH}.txt" 7 | ITEM_FREQ="${DATASET_PARTIAL_PATH}-train_item_freq.txt" 8 | SIMILARITY_INDICES="${DATASET_PARTIAL_PATH}-similarity-indices-${EMBEDDING_MODEL}.pt" 9 | SIMILARITY_VALUES="${DATASET_PARTIAL_PATH}-similarity-values-${EMBEDDING_MODEL}.pt" 10 | 11 | 12 | SIMILARITY_THREHOLD=0.7 13 | TEMPERATURE=0.5 14 | LAMBDA=0.6 15 | LAMBDA_SCHEDULING=LINEAR 16 | LAMBDA_WARMPUP=1000 17 | LAMBDA_STEPS=70000 18 | MAX_LEN=50 19 | BATCH_SIZE=128 20 | LR=0.0001 21 | DROPOUT=0.5 22 | NUM_BLOCKS=2 23 | EPOCHS=200 24 | DEVICE="cuda:0" 25 | HIDDEN_DIM=50 26 | TRAIN_DIR="results/pet/${NOW}" 27 | 28 | python main.py --dataset ${DATASET}\ 29 | --item_frequency ${ITEM_FREQ}\ 30 | --similarity_indices ${SIMILARITY_INDICES}\ 31 | --similarity_values ${SIMILARITY_VALUES}\ 32 | --similarity_threshold ${SIMILARITY_THREHOLD}\ 33 | --temperature ${TEMPERATURE}\ 34 | --lambd ${LAMBDA}\ 35 | --lambd_scheduling "${LAMBDA_SCHEDULING}"\ 36 | --lambd_warmup_steps ${LAMBDA_WARMPUP}\ 37 | --lambd_steps ${LAMBDA_STEPS}\ 38 | --batch_size ${BATCH_SIZE}\ 39 | --lr ${LR}\ 40 | --maxlen ${MAX_LEN}\ 41 | --dropout_rate ${DROPOUT}\ 42 | --num_blocks ${NUM_BLOCKS}\ 43 | --num_epochs ${EPOCHS}\ 44 | --hidden_units ${HIDDEN_DIM}\ 45 | --train_dir ${TRAIN_DIR}\ 46 | --device ${DEVICE} -------------------------------------------------------------------------------- /SimRec/scripts/HomeKitchen/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | NOW=`date +'%I_%M_%d_%m'` 3 | 4 | EMBEDDING_MODEL=thenlper_gte-large 5 | DATASET_PARTIAL_PATH="../data_preprocessing/HomeKitchen/HomeKitchen" 6 | DATASET="${DATASET_PARTIAL_PATH}.txt" 7 | ITEM_FREQ="${DATASET_PARTIAL_PATH}-train_item_freq.txt" 8 | SIMILARITY_INDICES="${DATASET_PARTIAL_PATH}-similarity-indices-${EMBEDDING_MODEL}.pt" 9 | SIMILARITY_VALUES="${DATASET_PARTIAL_PATH}-similarity-values-${EMBEDDING_MODEL}.pt" 10 | 11 | SIMILARITY_THREHOLD=0.6 12 | TEMPERATURE=0.5 13 | LAMBDA=0.7 14 | LAMBDA_SCHEDULING=LINEAR 15 | LAMBDA_WARMPUP=1000 16 | LAMBDA_STEPS=160000 17 | MAX_LEN=50 
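# The SIMILARITY_INDICES / SIMILARITY_VALUES paths above point at tensors produced by
# data_preprocessing/calculate_similarity_scores.ipynb, and LAMBDA, LAMBDA_SCHEDULING,
# LAMBDA_WARMPUP and LAMBDA_STEPS feed the --lambd* flags of main.py, which by their
# names appear to weight and schedule the similarity-based loss term (an interpretation
# inferred from the flag names, not from main.py itself). The variables below are
# standard training hyper-parameters.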
18 | BATCH_SIZE=128
19 | LR=0.0001
20 | DROPOUT=0.5
21 | NUM_BLOCKS=3
22 | EPOCHS=210
23 | DEVICE="cuda:0"
24 | HIDDEN_DIM=50
25 | TRAIN_DIR="results/homekitchen/${NOW}"
26 | 
27 | python main.py --dataset ${DATASET}\
28 | --item_frequency ${ITEM_FREQ}\
29 | --similarity_indices ${SIMILARITY_INDICES}\
30 | --similarity_values ${SIMILARITY_VALUES}\
31 | --similarity_threshold ${SIMILARITY_THREHOLD}\
32 | --temperature ${TEMPERATURE}\
33 | --lambd ${LAMBDA}\
34 | --lambd_scheduling "${LAMBDA_SCHEDULING}"\
35 | --lambd_warmup_steps ${LAMBDA_WARMPUP}\
36 | --lambd_steps ${LAMBDA_STEPS}\
37 | --batch_size ${BATCH_SIZE}\
38 | --lr ${LR}\
39 | --maxlen ${MAX_LEN}\
40 | --dropout_rate ${DROPOUT}\
41 | --num_blocks ${NUM_BLOCKS}\
42 | --num_epochs ${EPOCHS}\
43 | --hidden_units ${HIDDEN_DIM}\
44 | --train_dir ${TRAIN_DIR}\
45 | --device ${DEVICE}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimRec: Mitigating the Cold-Start Problem in Sequential Recommendation by Integrating Item Similarity
2 | This repository is the official implementation of the paper "SimRec: Mitigating the Cold-Start Problem in Sequential Recommendation by Integrating Item Similarity".
3 | Access the full paper here: https://www.amazon.science/publications/simrec-mitigating-the-cold-start-problem-in-sequential-recommendation-by-integrating-item-similarity
4 | 
5 | ## Datasets
6 | 
7 | [data_preprocessing](/data_preprocessing/) contains the code for generating the datasets.
8 | 
9 | This involves two steps:
10 | 1. Generate the dataset using the Jupyter notebook `preprocessing_data.ipynb` found in each sub-directory.
11 | 2. Calculate the similarity scores using the Jupyter notebook [`calculate_similarity_scores.ipynb`](/data_preprocessing/calculate_similarity_scores.ipynb).
12 | 
13 | After you have created the dataset(s), you can move on to training the model.
14 | 
15 | ## Model Training
16 | 
17 | [SimRec](/SimRec/) contains the code for training SimRec on the generated datasets.
18 | 
19 | To run SimRec, first change into the `SimRec` directory (the training scripts reference `main.py` and `../data_preprocessing` relative to it):
20 | ```bash
21 | cd SimRec
22 | ```
23 | 
24 | Train the model on the Beauty dataset:
25 | ```
26 | bash scripts/Beauty/train.sh
27 | ```
28 | Train the model on the Tools dataset:
29 | ```
30 | bash scripts/Tools/train.sh
31 | ```
32 | Train the model on the Pet Supplies dataset:
33 | ```
34 | bash scripts/PetSupplies/train.sh
35 | ```
36 | Train the model on the Home & Kitchen dataset:
37 | ```
38 | bash scripts/HomeKitchen/train.sh
39 | ```
40 | Train the model on the ML-1M dataset:
41 | ```
42 | bash scripts/ML-1M/train.sh
43 | ```
44 | Train the model on the Steam dataset:
45 | ```
46 | bash scripts/Steam/train.sh
47 | ```
48 | 
49 | Note that the implementation of SimRec is based on the [SASRec.pytorch](https://github.com/pmixer/SASRec.pytorch) repository.
50 | 
51 | ## Requirements
52 | 
53 | ```
54 | pip install -r requirements.txt
55 | ```
56 | 
57 | ## Hardware
58 | In general, all experiments can run on either a GPU or a CPU.
59 | 
60 | ## License
61 | This project is licensed under the Apache-2.0 License.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 | 
3 | Thank you for your interest in contributing to our project.
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 
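The training scripts above all read two tensors, `*-similarity-indices-<embedding model>.pt` and `*-similarity-values-<embedding model>.pt`, which the README attributes to [`calculate_similarity_scores.ipynb`](/data_preprocessing/calculate_similarity_scores.ipynb) (its contents are not shown in this section). As orientation before the model code below, here is a minimal sketch of how such top-k title-similarity tensors could be produced with `sentence-transformers` from the `*-titles.txt` files written by the preprocessing notebooks. It is an illustration of the expected file layout, not the notebook's actual code; `TOP_K`, the output paths, and the assumption of contiguous 0-based item ids are illustrative choices.

```python
# Hedged sketch only: not the repository's calculate_similarity_scores.ipynb.
# Embeds item titles, keeps the top-k cosine similarities per item, and saves the two
# .pt files that SimRec/scripts/*/train.sh pass as --similarity_indices / --similarity_values.
import torch
from sentence_transformers import SentenceTransformer

TITLES_PATH = "Beauty/Beauty-titles.txt"   # lines look like: <item_id> "<title>"
OUT_PREFIX = "Beauty/Beauty"               # matches DATASET_PARTIAL_PATH in train.sh (run from data_preprocessing/)
MODEL_NAME = "thenlper/gte-large"          # saved under the tag thenlper_gte-large in the file names
TOP_K = 100                                # assumption: number of neighbours kept per item

# Read titles ordered by item id (the notebooks write contiguous 0-based ids;
# the interaction files shift them by +1 to reserve index 0 for padding).
titles = {}
with open(TITLES_PATH) as f:
    for line in f:
        item_id, title = line.rstrip("\n").split(" ", 1)
        titles[int(item_id)] = title.strip('"')
ordered_titles = [titles[i] for i in range(len(titles))]

encoder = SentenceTransformer(MODEL_NAME)
emb = encoder.encode(ordered_titles, convert_to_tensor=True, normalize_embeddings=True)

sims = emb @ emb.T                         # cosine similarity, since embeddings are normalized
values, indices = torch.topk(sims, k=min(TOP_K, sims.shape[0]), dim=-1)

tag = MODEL_NAME.replace("/", "_")
torch.save(indices.cpu(), f"{OUT_PREFIX}-similarity-indices-{tag}.pt")
torch.save(values.cpu(), f"{OUT_PREFIX}-similarity-values-{tag}.pt")
```

How main.py aligns these rows with its internal 1-based item ids is determined by its loading code, which is not included in this section.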
60 | -------------------------------------------------------------------------------- /SimRec/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils import * 4 | 5 | 6 | class PointWiseFeedForward(torch.nn.Module): 7 | def __init__(self, hidden_units, dropout_rate): 8 | 9 | super(PointWiseFeedForward, self).__init__() 10 | 11 | self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 12 | self.dropout1 = torch.nn.Dropout(p=dropout_rate) 13 | self.relu = torch.nn.ReLU() 14 | self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1) 15 | self.dropout2 = torch.nn.Dropout(p=dropout_rate) 16 | 17 | def forward(self, inputs): 18 | outputs = self.dropout2(self.conv2(self.relu(self.dropout1(self.conv1(inputs.transpose(-1, -2)))))) 19 | outputs = outputs.transpose(-1, -2) # as Conv1D requires (N, C, Length) 20 | outputs += inputs 21 | return outputs 22 | 23 | # pls use the following self-made multihead attention layer 24 | # in case your pytorch version is below 1.16 or for other reasons 25 | # https://github.com/pmixer/TiSASRec.pytorch/blob/master/model.py 26 | 27 | class SimRec(torch.nn.Module): 28 | def __init__(self, user_num, item_num, args): 29 | super(SimRec, self).__init__() 30 | 31 | self.user_num = user_num 32 | self.item_num = item_num 33 | self.dev = args.device 34 | self.loss = args.loss 35 | self.training_mode = args.training_mode 36 | 37 | # TODO: loss += args.l2_emb for regularizing embedding vectors during training 38 | # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch 39 | self.item_emb = torch.nn.Embedding(self.item_num+1, args.hidden_units, padding_idx=0) 40 | self.pos_emb = torch.nn.Embedding(args.maxlen, args.hidden_units) # TO IMPROVE 41 | self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate) 42 | 43 | 44 | self.attention_layernorms = torch.nn.ModuleList() # to be Q for self-attention 45 | self.attention_layers = torch.nn.ModuleList() 46 | self.forward_layernorms = torch.nn.ModuleList() 47 | self.forward_layers = torch.nn.ModuleList() 48 | 49 | self.last_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) 50 | 51 | for _ in range(args.num_blocks): 52 | new_attn_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) 53 | self.attention_layernorms.append(new_attn_layernorm) 54 | 55 | new_attn_layer = torch.nn.MultiheadAttention(args.hidden_units, 56 | args.num_heads, 57 | args.dropout_rate) 58 | self.attention_layers.append(new_attn_layer) 59 | 60 | new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8) 61 | self.forward_layernorms.append(new_fwd_layernorm) 62 | 63 | new_fwd_layer = PointWiseFeedForward(args.hidden_units, args.dropout_rate) 64 | self.forward_layers.append(new_fwd_layer) 65 | 66 | 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | for name, param in self.named_parameters(): 71 | try: 72 | torch.nn.init.xavier_normal_(param.data) 73 | except: 74 | pass # just ignore those failed init layers 75 | 76 | def log2feats(self, log_seqs): 77 | if not torch.is_tensor(log_seqs): 78 | log_seqs = torch.LongTensor(log_seqs).to(self.dev) 79 | seqs = self.item_emb(log_seqs) 80 | seqs *= self.item_emb.embedding_dim ** 0.5 81 | positions = np.tile(np.array(range(log_seqs.shape[1])), [log_seqs.shape[0], 1]) 82 | seqs += self.pos_emb(torch.LongTensor(positions).to(self.dev)) 83 | seqs = self.emb_dropout(seqs) 84 | 85 | timeline_mask = log_seqs == 0 86 | seqs *= 
~timeline_mask.unsqueeze(-1) # broadcast in last dim 87 | 88 | tl = seqs.shape[1] # time dim len for enforce causality 89 | attention_mask = ~torch.tril(torch.ones((tl, tl), dtype=torch.bool, device=self.dev)) 90 | 91 | for i in range(len(self.attention_layers)): 92 | seqs = torch.transpose(seqs, 0, 1) 93 | Q = self.attention_layernorms[i](seqs) 94 | mha_outputs, _ = self.attention_layers[i](Q, seqs, seqs, 95 | attn_mask=attention_mask) 96 | # key_padding_mask=timeline_mask 97 | # need_weights=False) this arg do not work? 98 | seqs = Q + mha_outputs 99 | seqs = torch.transpose(seqs, 0, 1) 100 | 101 | seqs = self.forward_layernorms[i](seqs) 102 | seqs = self.forward_layers[i](seqs) 103 | seqs *= ~timeline_mask.unsqueeze(-1) 104 | 105 | log_feats = self.last_layernorm(seqs) # (U, T, C) -> (U, -1, C) 106 | 107 | return log_feats 108 | 109 | def forward(self, user_ids, log_seqs, pos_seqs, neg_seqs): # for training 110 | # (batch_size, max_len, hidden_dim) ? 111 | log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet 112 | pos_logits = None 113 | neg_logits = None 114 | logits = None 115 | 116 | if not torch.is_tensor(pos_seqs): 117 | pos_seqs = torch.LongTensor(pos_seqs).to(self.dev) 118 | if not torch.is_tensor(neg_seqs): 119 | neg_seqs = torch.LongTensor(neg_seqs).to(self.dev) 120 | pos_embs = self.item_emb(pos_seqs) 121 | neg_embs = self.item_emb(neg_seqs) 122 | 123 | pos_logits = (log_feats * pos_embs).sum(dim=-1) 124 | neg_logits = (log_feats * neg_embs).sum(dim=-1) 125 | 126 | embed = self.item_emb.weight 127 | logits = log_feats @ embed.T 128 | return pos_logits, neg_logits, logits 129 | 130 | 131 | 132 | def predict(self, user_ids, log_seqs, item_indices): # for inference 133 | log_feats = self.log2feats(log_seqs) # user_ids hasn't been used yet 134 | 135 | final_feat = log_feats[:, -1, :] # only use last QKV classifier, a waste 136 | 137 | if not torch.is_tensor(item_indices): 138 | item_indices = torch.LongTensor(item_indices).to(self.dev) 139 | 140 | item_embs = self.item_emb(item_indices) # (U, I, C) 141 | 142 | logits = item_embs.matmul(final_feat.unsqueeze(-1)).squeeze(-1) 143 | 144 | return logits # preds # (U, I) 145 | -------------------------------------------------------------------------------- /data_preprocessing/ML-1M/preprocessing_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import os\n", 11 | "import wget\n", 12 | "import zipfile\n", 13 | "from tqdm.notebook import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "url = \"https://files.grouplens.org/datasets/movielens/ml-1m.zip\"\n", 23 | "zip_path = \"ml-1m.zip\"\n", 24 | "raw_dir =\"ml-1m\"\n", 25 | "items_path = os.path.join(raw_dir,\"movies.dat\")\n", 26 | "ratings_path = os.path.join(raw_dir,\"ratings.dat\")\n", 27 | "\n", 28 | "# download raw dataset\n", 29 | "if not os.path.exists(zip_path):\n", 30 | " wget.download(url)\n", 31 | "\n", 32 | "if not os.path.exists(raw_dir):\n", 33 | " with zipfile.ZipFile(zip_path,\"r\") as zip_ref:\n", 34 | " zip_ref.extractall()" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "ITEM_FREQ_MIN = 5\n", 44 | "REVIEWS_REMOVE_LESS_THAN = 5\n", 45 | 
"\n", 46 | "SEP=\"::\"\n", 47 | "INTERNAL_SEP=\"|\"\n", 48 | "\n", 49 | "out_path = \"ml-1m.txt\"\n", 50 | "id_to_title_map_path = \"ml-1m-titles.txt\"\n", 51 | "train_item_freq_path = \"ml-1m-train_item_freq.txt\"\n" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "# load items data\n", 61 | "items = dict()\n", 62 | "with open(items_path, \"r\", encoding ='ISO-8859-1') as f:\n", 63 | " for line in f:\n", 64 | " item_id, title, genres = line.split(SEP)\n", 65 | " items[item_id] = title\n", 66 | "\n", 67 | "print(f\"Found {len(items)} items\")\n", 68 | "\n", 69 | "# load rewview data\n", 70 | "reviews = defaultdict(list)\n", 71 | "item_freq = defaultdict(int)\n", 72 | "skipped = 0\n", 73 | "with open(ratings_path, \"r\", encoding=\"utf-8\") as f:\n", 74 | " for line in f:\n", 75 | " user_id, item_id, rating, timestemp = line.split(SEP)\n", 76 | " if item_id in items:\n", 77 | " reviews[user_id].append((item_id, int(timestemp)))\n", 78 | " item_freq[item_id] += 1\n", 79 | " else:\n", 80 | " skipped += 1\n", 81 | "\n", 82 | "print(f\"Found {len(reviews)} users\")\n", 83 | "print(f\"Found {sum(item_freq.values())} reviews\")\n", 84 | "print(f\"Skipepd {skipped} item reviews without metadata\")\n", 85 | " \n", 86 | "item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}\n", 87 | "\n", 88 | "# remove user with less than K reviews\n", 89 | "removed_users_less_than = 0\n", 90 | "removed_users_item_less_than = 0\n", 91 | "removed_items = 0\n", 92 | "updated_items = set()\n", 93 | "for user_id in list(reviews.keys()):\n", 94 | " if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:\n", 95 | " del reviews[user_id]\n", 96 | " removed_users_less_than += 1\n", 97 | " else:\n", 98 | " len_before = len(reviews[user_id])\n", 99 | " reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]\n", 100 | " updated_items.update([t[0] for t in reviews[user_id]])\n", 101 | " removed_items += len_before - len(reviews[user_id])\n", 102 | " if len(reviews[user_id]) <= 0:\n", 103 | " del reviews[user_id]\n", 104 | " removed_users_item_less_than += 1\n", 105 | "print(f\"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total\")\n", 106 | "print(f\"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions\")\n", 107 | "print(f\"Removed {removed_users_item_less_than} users with only item count less than {REVIEWS_REMOVE_LESS_THAN}\")\n", 108 | "\n", 109 | "# calculate item frequencey again \n", 110 | "original_item_freq = item_freq\n", 111 | "item_freq = defaultdict(int)\n", 112 | "for user_id, rating_list in reviews.items():\n", 113 | " for item, timestamp in rating_list:\n", 114 | " item_freq[item] += 1\n", 115 | " \n", 116 | "item_freq = dict(sorted(item_freq.items()))\n", 117 | "print(f\"Total of {sum(item_freq.values())} reviews\")\n", 118 | "\n", 119 | "# remove \"unused\" items\n", 120 | "new_items = {}\n", 121 | "new_item_freq = {}\n", 122 | "new_original_item_freq = {}\n", 123 | "for asin in tqdm(updated_items):\n", 124 | " new_items[asin] = items[asin]\n", 125 | " new_item_freq[asin] = item_freq[asin]\n", 126 | " new_original_item_freq[asin] = original_item_freq[asin]\n", 127 | "print(f\"Removed {len(items) - len(new_items)} items that are not been reviewd\")\n", 128 | "item_freq = new_item_freq\n", 129 | "items = new_items\n", 130 | "original_item_freq = new_original_item_freq\n", 131 | "\n", 132 | "\n", 133 | "print()\n", 
134 | "print(f\"Items Reviews Users\")\n", 135 | "print(f\"{len(items):<4} {sum(len(v) for v in reviews.values()):<7} {len(reviews):<5}\")\n", 136 | "\n", 137 | "# fix user id\n", 138 | "user_id_mapping = dict()\n", 139 | "i = 0\n", 140 | "for original_user_id in reviews:\n", 141 | " user_id_mapping[original_user_id] = i\n", 142 | " i += 1\n", 143 | "\n", 144 | "# fix items ids\n", 145 | "item_id_mapping = dict()\n", 146 | "i = 0\n", 147 | "for asin in items:\n", 148 | " item_id_mapping[asin] = i\n", 149 | " i += 1\n", 150 | "\n", 151 | "train_item_freq = {k: 0 for k in item_freq.keys()}\n", 152 | "val_item_freq = {k: 0 for k in item_freq.keys()}\n", 153 | "test_item_freq = {k: 0 for k in item_freq.keys()}\n", 154 | "for user_id, rating_list in reviews.items():\n", 155 | " sorted_list = list(map(lambda t: t[0], sorted(rating_list, key=lambda t: t[1])))\n", 156 | " if len(sorted_list) < 3:\n", 157 | " train_list = sorted_list\n", 158 | " else:\n", 159 | " train_list = sorted_list[1:-2]\n", 160 | " val_item_freq[sorted_list[-2]] += 1\n", 161 | " test_item_freq[sorted_list[-1]] += 1 \n", 162 | " for asin in train_list:\n", 163 | " train_item_freq[asin] += 1\n", 164 | "\n", 165 | "with open(out_path, \"w\") as f:\n", 166 | " for user_id, rating_list in reviews.items():\n", 167 | " sorted_list = sorted(rating_list, key=lambda t: t[1])\n", 168 | " for item_id, timestamp in sorted_list:\n", 169 | " f.write(f\"{user_id_mapping[user_id] + 1} {item_id_mapping[item_id] + 1}\\n\") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding\n", 170 | "\n", 171 | "with open(id_to_title_map_path, \"w\") as f:\n", 172 | " for asin, title in items.items():\n", 173 | " f.write(f'{item_id_mapping[asin]} \"{title}\"\\n')\n", 174 | "\n", 175 | "with open(train_item_freq_path, \"w\") as f:\n", 176 | " for asin, count in train_item_freq.items():\n", 177 | " f.write(f'{item_id_mapping[asin]} {count}\\n')" 178 | ] 179 | } 180 | ], 181 | "metadata": { 182 | "kernelspec": { 183 | "display_name": "py39", 184 | "language": "python", 185 | "name": "python3" 186 | }, 187 | "language_info": { 188 | "codemirror_mode": { 189 | "name": "ipython", 190 | "version": 3 191 | }, 192 | "file_extension": ".py", 193 | "mimetype": "text/x-python", 194 | "name": "python", 195 | "nbconvert_exporter": "python", 196 | "pygments_lexer": "ipython3", 197 | "version": "3.9.17" 198 | }, 199 | "orig_nbformat": 4 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 2 203 | } 204 | -------------------------------------------------------------------------------- /data_preprocessing/Steam/preprocessing_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import os\n", 11 | "from datetime import datetime\n", 12 | "import time\n", 13 | "import wget\n", 14 | "import gzip\n", 15 | "from tqdm.notebook import tqdm" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "base_url = \"http://cseweb.ucsd.edu/~wckang/\"\n", 25 | "reviews_path = \"steam_reviews.json.gz\"\n", 26 | "meta_path = \"steam_games.json.gz\"\n", 27 | "\n", 28 | "# download raw dataset\n", 29 | "if not os.path.exists(reviews_path):\n", 30 | " wget.download(base_url + reviews_path)\n", 31 | "if not os.path.exists(meta_path):\n", 32 | " 
wget.download(base_url + meta_path)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "ITEM_FREQ_MIN = 5\n", 42 | "REVIEWS_REMOVE_LESS_THAN = 5\n", 43 | "\n", 44 | "out_path = \"steam.txt\"\n", 45 | "id_to_title_map_path = \"steam-titles.txt\"\n", 46 | "train_item_freq_path = \"steam-train_item_freq.txt\"" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "# load items data\n", 56 | "items = dict()\n", 57 | "skipped = 0\n", 58 | "with gzip.open(meta_path, \"r\") as f:\n", 59 | " for line in tqdm(f):\n", 60 | " json_obj = eval(line)\n", 61 | " if 'title' in json_obj and 'id' in json_obj:\n", 62 | " asin = json_obj['id']\n", 63 | " title = json_obj['title'].replace(\"\\\"\", \"'\")\n", 64 | " title = title.replace(\"\\n\", \" \")\n", 65 | " if len(title) >= 2:\n", 66 | " items[asin] = title\n", 67 | " else:\n", 68 | " skipped +=1\n", 69 | " else:\n", 70 | " skipped += 1\n", 71 | "\n", 72 | "print(f\"Found {len(items)} items\")\n", 73 | "print(f\"Skipped {skipped} items without id or title\")\n", 74 | "\n", 75 | "# load reviews data\n", 76 | "reviews = defaultdict(list)\n", 77 | "item_freq = defaultdict(int)\n", 78 | "skipped = 0\n", 79 | "with gzip.open(reviews_path, \"r\") as f:\n", 80 | " for line in tqdm(f):\n", 81 | " json_obj = eval(line)\n", 82 | " user_id = json_obj['username']\n", 83 | " asin = json_obj['product_id']\n", 84 | " timestemp = json_obj['date']\n", 85 | " if asin in items:\n", 86 | " date_time = datetime.strptime(timestemp, \"%Y-%m-%d\")\n", 87 | " unix_timestamp = int(time.mktime(date_time.timetuple()))\n", 88 | " reviews[user_id].append((asin, unix_timestamp))\n", 89 | " item_freq[asin] += 1\n", 90 | " else:\n", 91 | " skipped += 1\n", 92 | " # print(f\"skipped {asin}\")\n", 93 | "\n", 94 | "print(f\"Found {len(reviews)} users\")\n", 95 | "print(f\"Found {sum(item_freq.values())} reviews\")\n", 96 | "print(f\"Skipepd {skipped} item reviews without metadata\")\n", 97 | "\n", 98 | " \n", 99 | "item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}\n", 100 | "\n", 101 | "# remove user with less than K reviews\n", 102 | "removed_users_less_than = 0\n", 103 | "removed_users_item_less_than = 0\n", 104 | "removed_items = 0\n", 105 | "updated_items = set()\n", 106 | "for user_id in list(reviews.keys()):\n", 107 | " if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:\n", 108 | " del reviews[user_id]\n", 109 | " removed_users_less_than += 1\n", 110 | " else:\n", 111 | " len_before = len(reviews[user_id])\n", 112 | " reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]\n", 113 | " updated_items.update([t[0] for t in reviews[user_id]])\n", 114 | " removed_items += len_before - len(reviews[user_id])\n", 115 | " if len(reviews[user_id]) <= 0:\n", 116 | " del reviews[user_id]\n", 117 | " removed_users_item_less_than += 1\n", 118 | "print(f\"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total\")\n", 119 | "print(f\"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions\")\n", 120 | "print(f\"Removed {removed_users_item_less_than} users with only item count less than {REVIEWS_REMOVE_LESS_THAN}\")\n", 121 | "\n", 122 | "# calculate item frequencey again \n", 123 | "original_item_freq = item_freq\n", 124 | "item_freq = defaultdict(int)\n", 125 | "for user_id, rating_list in reviews.items():\n", 
126 | " for item, timestamp in rating_list:\n", 127 | " item_freq[item] += 1\n", 128 | " \n", 129 | "item_freq = dict(sorted(item_freq.items()))\n", 130 | "print(f\"Total of {sum(item_freq.values())} reviews\")\n", 131 | "\n", 132 | "# remove \"unused\" items\n", 133 | "new_items = {}\n", 134 | "new_item_freq = {}\n", 135 | "new_original_item_freq = {}\n", 136 | "for asin in tqdm(updated_items):\n", 137 | " new_items[asin] = items[asin]\n", 138 | " new_item_freq[asin] = item_freq[asin]\n", 139 | " new_original_item_freq[asin] = original_item_freq[asin]\n", 140 | "print(f\"Removed {len(items) - len(new_items)} items that are not been reviewd\")\n", 141 | "item_freq = new_item_freq\n", 142 | "items = new_items\n", 143 | "original_item_freq = new_original_item_freq\n", 144 | "\n", 145 | "\n", 146 | "print()\n", 147 | "print(f\"Items Reviews Users\")\n", 148 | "print(f\"{len(items):<4} {sum(len(v) for v in reviews.values()):<7} {len(reviews):<5}\")\n", 149 | "\n", 150 | "# fix user id\n", 151 | "user_id_mapping = dict()\n", 152 | "i = 0\n", 153 | "for original_user_id in reviews:\n", 154 | " user_id_mapping[original_user_id] = i\n", 155 | " i += 1\n", 156 | "\n", 157 | "# fix items ids\n", 158 | "item_id_mapping = dict()\n", 159 | "i = 0\n", 160 | "for asin in items:\n", 161 | " item_id_mapping[asin] = i\n", 162 | " i += 1\n", 163 | "\n", 164 | "train_item_freq = {k: 0 for k in item_freq.keys()}\n", 165 | "val_item_freq = {k: 0 for k in item_freq.keys()}\n", 166 | "test_item_freq = {k: 0 for k in item_freq.keys()}\n", 167 | "for user_id, rating_list in reviews.items():\n", 168 | " sorted_list = list(map(lambda t: t[0], sorted(rating_list, key=lambda t: t[1])))\n", 169 | " if len(sorted_list) < 3:\n", 170 | " train_list = sorted_list\n", 171 | " else:\n", 172 | " train_list = sorted_list[1:-2]\n", 173 | " val_item_freq[sorted_list[-2]] += 1\n", 174 | " test_item_freq[sorted_list[-1]] += 1 \n", 175 | " for asin in train_list:\n", 176 | " train_item_freq[asin] += 1\n", 177 | "\n", 178 | "with open(out_path, \"w\") as f:\n", 179 | " for user_id, rating_list in reviews.items():\n", 180 | " sorted_list = sorted(rating_list, key=lambda t: t[1])\n", 181 | " for item_id, timestamp in sorted_list:\n", 182 | " f.write(f\"{user_id_mapping[user_id] + 1} {item_id_mapping[item_id] + 1}\\n\") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding\n", 183 | "\n", 184 | "with open(id_to_title_map_path, \"w\") as f:\n", 185 | " for asin, title in items.items():\n", 186 | " f.write(f'{item_id_mapping[asin]} \"{title}\"\\n')\n", 187 | "\n", 188 | "with open(train_item_freq_path, \"w\") as f:\n", 189 | " for asin, count in train_item_freq.items():\n", 190 | " f.write(f'{item_id_mapping[asin]} {count}\\n')" 191 | ] 192 | } 193 | ], 194 | "metadata": { 195 | "kernelspec": { 196 | "display_name": "py39", 197 | "language": "python", 198 | "name": "python3" 199 | }, 200 | "language_info": { 201 | "codemirror_mode": { 202 | "name": "ipython", 203 | "version": 3 204 | }, 205 | "file_extension": ".py", 206 | "mimetype": "text/x-python", 207 | "name": "python", 208 | "nbconvert_exporter": "python", 209 | "pygments_lexer": "ipython3", 210 | "version": "3.9.13" 211 | }, 212 | "orig_nbformat": 4 213 | }, 214 | "nbformat": 4, 215 | "nbformat_minor": 2 216 | } 217 | -------------------------------------------------------------------------------- /data_preprocessing/Beauty/preprocessing_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 
3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import os\n", 11 | "import wget\n", 12 | "import gzip\n", 13 | "from tqdm.notebook import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "base_url = \"http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/\"\n", 23 | "reviews_path = \"reviews_Beauty.json.gz\"\n", 24 | "meta_path = \"meta_Beauty.json.gz\"\n", 25 | "\n", 26 | "# download raw dataset\n", 27 | "if not os.path.exists(reviews_path):\n", 28 | " wget.download(base_url + reviews_path)\n", 29 | "if not os.path.exists(meta_path):\n", 30 | " wget.download(base_url + meta_path)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ITEM_FREQ_MIN = 5\n", 40 | "REVIEWS_REMOVE_LESS_THAN = 5\n", 41 | "\n", 42 | "out_path = \"Beauty.txt\"\n", 43 | "id_to_title_map_path = \"Beauty-titles.txt\"\n", 44 | "id_to_asin_map_path = \"Beauty-id_to_asin.txt\"\n", 45 | "train_item_freq_path = \"Beauty-train_item_freq.txt\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# load items data\n", 55 | "items = dict()\n", 56 | "skipped = 0\n", 57 | "with gzip.open(meta_path, \"r\") as f:\n", 58 | " for line in tqdm(f):\n", 59 | " json_obj = eval(line)\n", 60 | " asin = json_obj['asin']\n", 61 | " if 'title' in json_obj:\n", 62 | " title = json_obj['title'].replace(\"\\\"\", \"'\")\n", 63 | " title = title.replace(\"\\n\", \" \")\n", 64 | " if len(title) >= 2:\n", 65 | " items[asin] = title\n", 66 | " else:\n", 67 | " skipped +=1\n", 68 | " else:\n", 69 | " skipped += 1\n", 70 | "\n", 71 | "print(f\"Found {len(items)} items\")\n", 72 | "print(f\"Skipped {skipped} items without title\")\n", 73 | "\n", 74 | "# load reviews data\n", 75 | "reviews = defaultdict(list)\n", 76 | "item_freq = defaultdict(int)\n", 77 | "skipped = 0\n", 78 | "with gzip.open(reviews_path, \"r\") as f:\n", 79 | " for line in tqdm(f):\n", 80 | " json_obj = eval(line)\n", 81 | " user_id = json_obj['reviewerID']\n", 82 | " asin = json_obj['asin']\n", 83 | " timestemp = json_obj['unixReviewTime']\n", 84 | " if asin in items:\n", 85 | " reviews[user_id].append((asin, int(timestemp)))\n", 86 | " item_freq[asin] += 1\n", 87 | " else:\n", 88 | " skipped += 1\n", 89 | " # print(f\"skipped {asin}\")\n", 90 | "\n", 91 | "print(f\"Found {len(reviews)} users\")\n", 92 | "print(f\"Found {sum(item_freq.values())} reviews\")\n", 93 | "print(f\"Skipepd {skipped} item reviews without metadata\")\n", 94 | "\n", 95 | "item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}\n", 96 | "\n", 97 | "item_freq = dict(sorted(item_freq.items()))\n", 98 | "\n", 99 | "# remove user with less than K reviews\n", 100 | "removed_users_less_than = 0\n", 101 | "removed_users_item_less_than = 0\n", 102 | "removed_items = 0\n", 103 | "updated_items = set()\n", 104 | "for user_id in list(reviews.keys()):\n", 105 | " if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:\n", 106 | " del reviews[user_id]\n", 107 | " removed_users_less_than += 1\n", 108 | " else:\n", 109 | " len_before = len(reviews[user_id])\n", 110 | " reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]\n", 111 | " updated_items.update([t[0] for t in 
reviews[user_id]])\n", 112 | " removed_items += len_before - len(reviews[user_id])\n", 113 | " if len(reviews[user_id]) <= 0:\n", 114 | " del reviews[user_id]\n", 115 | " removed_users_item_less_than += 1\n", 116 | "print(f\"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total\")\n", 117 | "print(f\"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions\")\n", 118 | "print(f\"Removed {removed_users_item_less_than} users with only item count less than {REVIEWS_REMOVE_LESS_THAN}\")\n", 119 | "\n", 120 | "# calculate item frequencey again \n", 121 | "original_item_freq = item_freq\n", 122 | "item_freq = defaultdict(int)\n", 123 | "for user_id, rating_list in reviews.items():\n", 124 | " for item, timestamp in rating_list:\n", 125 | " item_freq[item] += 1\n", 126 | " \n", 127 | "item_freq = dict(sorted(item_freq.items()))\n", 128 | "print(f\"Total of {sum(item_freq.values())} reviews\")\n", 129 | "\n", 130 | "# remove \"unused\" items\n", 131 | "new_items = {}\n", 132 | "new_item_freq = {}\n", 133 | "new_original_item_freq = {}\n", 134 | "for asin in tqdm(updated_items):\n", 135 | " new_items[asin] = items[asin]\n", 136 | " new_item_freq[asin] = item_freq[asin]\n", 137 | " new_original_item_freq[asin] = original_item_freq[asin]\n", 138 | "print(f\"Removed {len(items) - len(new_items)} items that are not been reviewd\")\n", 139 | "item_freq = new_item_freq\n", 140 | "items = new_items\n", 141 | "original_item_freq = new_original_item_freq\n", 142 | "\n", 143 | "\n", 144 | "print()\n", 145 | "print(f\"Items Reviews Users\")\n", 146 | "print(f\"{len(items):<4} {sum(len(v) for v in reviews.values()):<7} {len(reviews):<5}\")\n", 147 | "\n", 148 | "# fix user id\n", 149 | "user_id_mapping = dict()\n", 150 | "i = 0\n", 151 | "for original_user_id in reviews:\n", 152 | " user_id_mapping[original_user_id] = i\n", 153 | " i += 1\n", 154 | "\n", 155 | "# fix items ids\n", 156 | "item_id_mapping = dict()\n", 157 | "i = 0\n", 158 | "for asin in items:\n", 159 | " item_id_mapping[asin] = i\n", 160 | " i += 1\n", 161 | "\n", 162 | "train_item_freq = {k: 0 for k in item_freq.keys()}\n", 163 | "val_item_freq = {k: 0 for k in item_freq.keys()}\n", 164 | "test_item_freq = {k: 0 for k in item_freq.keys()}\n", 165 | "for user_id, rating_list in reviews.items():\n", 166 | " sorted_list = list(map(lambda t: t[0], sorted(rating_list, key=lambda t: t[1])))\n", 167 | " if len(sorted_list) < 3:\n", 168 | " train_list = sorted_list\n", 169 | " else:\n", 170 | " train_list = sorted_list[1:-2]\n", 171 | " val_item_freq[sorted_list[-2]] += 1\n", 172 | " test_item_freq[sorted_list[-1]] += 1 \n", 173 | " for asin in train_list:\n", 174 | " train_item_freq[asin] += 1\n", 175 | " \n", 176 | "\n", 177 | "with open(out_path, \"w\") as f:\n", 178 | " for user_id, rating_list in reviews.items():\n", 179 | " sorted_list = sorted(rating_list, key=lambda t: t[1])\n", 180 | " for asin, timestamp in sorted_list:\n", 181 | " f.write(f\"{user_id_mapping[user_id] + 1} {item_id_mapping[asin] + 1}\\n\") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding\n", 182 | "\n", 183 | "with open(id_to_title_map_path, \"w\") as f:\n", 184 | " for asin, title in items.items():\n", 185 | " f.write(f'{item_id_mapping[asin]} \"{title}\"\\n')\n", 186 | "\n", 187 | "with open(id_to_asin_map_path, \"w\") as f:\n", 188 | " for asin, item_id in item_id_mapping.items():\n", 189 | " f.write(f'{item_id} {asin}\\n')\n", 190 | "\n", 191 | "with 
open(train_item_freq_path, \"w\") as f:\n", 192 | " for asin, count in train_item_freq.items():\n", 193 | " f.write(f'{item_id_mapping[asin]} {count}\\n')" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "py39", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.9.17" 214 | }, 215 | "orig_nbformat": 4 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /data_preprocessing/PetSupplies/preprocessing_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import os\n", 11 | "import wget\n", 12 | "import gzip\n", 13 | "from tqdm.notebook import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "base_url = \"http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/\"\n", 23 | "reviews_path = \"reviews_Pet_Supplies.json.gz\"\n", 24 | "meta_path = \"meta_Pet_Supplies.json.gz\"\n", 25 | "\n", 26 | "# download raw dataset\n", 27 | "if not os.path.exists(reviews_path):\n", 28 | " wget.download(base_url + reviews_path)\n", 29 | "if not os.path.exists(meta_path):\n", 30 | " wget.download(base_url + meta_path)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ITEM_FREQ_MIN = 5\n", 40 | "REVIEWS_REMOVE_LESS_THAN = 5\n", 41 | "\n", 42 | "out_path = \"Pet.txt\"\n", 43 | "id_to_title_map_path = \"Pet-titles.txt\"\n", 44 | "id_to_asin_map_path = \"Pet-id_to_asin.txt\"\n", 45 | "train_item_freq_path = \"Pet-train_item_freq.txt\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# load items data\n", 55 | "items = dict()\n", 56 | "skipped = 0\n", 57 | "with gzip.open(meta_path, \"r\") as f:\n", 58 | " for line in tqdm(f):\n", 59 | " json_obj = eval(line)\n", 60 | " asin = json_obj['asin']\n", 61 | " if 'title' in json_obj:\n", 62 | " title = json_obj['title'].replace(\"\\\"\", \"'\")\n", 63 | " title = title.replace(\"\\n\", \" \").replace('"', '\\\\\"').replace('&', '&').replace('®', '').replace('™', '').replace('é', 'e').replace('°', '').replace('<', '<').replace('>', '>').replace(' ', ' ').replace('&frac', '/')\n", 64 | " if len(title) >= 2:\n", 65 | " items[asin] = title\n", 66 | " else:\n", 67 | " skipped +=1\n", 68 | " else:\n", 69 | " skipped += 1\n", 70 | "\n", 71 | "print(f\"Found {len(items)} items\")\n", 72 | "print(f\"Skipped {skipped} items without title\")\n", 73 | "\n", 74 | "# load reviews data\n", 75 | "reviews = defaultdict(list)\n", 76 | "item_freq = defaultdict(int)\n", 77 | "skipped = 0\n", 78 | "with gzip.open(reviews_path, \"r\") as f:\n", 79 | " for line in tqdm(f):\n", 80 | " json_obj = eval(line)\n", 81 | " user_id = json_obj['reviewerID']\n", 82 | " asin = json_obj['asin']\n", 83 | " timestemp = json_obj['unixReviewTime']\n", 84 | " if asin in 
items:\n", 85 | " reviews[user_id].append((asin, int(timestemp)))\n", 86 | " item_freq[asin] += 1\n", 87 | " else:\n", 88 | " skipped += 1\n", 89 | " # print(f\"skipped {asin}\")\n", 90 | "\n", 91 | "print(f\"Found {len(reviews)} users\")\n", 92 | "print(f\"Found {sum(item_freq.values())} reviews\")\n", 93 | "print(f\"Skipepd {skipped} item reviews without metadata\")\n", 94 | "\n", 95 | "item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}\n", 96 | "\n", 97 | "item_freq = dict(sorted(item_freq.items()))\n", 98 | "\n", 99 | "# remove user with less than K reviews\n", 100 | "removed_users_less_than = 0\n", 101 | "removed_users_item_less_than = 0\n", 102 | "removed_items = 0\n", 103 | "updated_items = set()\n", 104 | "for user_id in list(reviews.keys()):\n", 105 | " if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:\n", 106 | " del reviews[user_id]\n", 107 | " removed_users_less_than += 1\n", 108 | " else:\n", 109 | " len_before = len(reviews[user_id])\n", 110 | " reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]\n", 111 | " updated_items.update([t[0] for t in reviews[user_id]])\n", 112 | " removed_items += len_before - len(reviews[user_id])\n", 113 | " if len(reviews[user_id]) <= 0:\n", 114 | " del reviews[user_id]\n", 115 | " removed_users_item_less_than += 1\n", 116 | "print(f\"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total\")\n", 117 | "print(f\"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions\")\n", 118 | "print(f\"Removed {removed_users_item_less_than} users with only item count less than {REVIEWS_REMOVE_LESS_THAN}\")\n", 119 | "\n", 120 | "# calculate item frequencey again \n", 121 | "original_item_freq = item_freq\n", 122 | "item_freq = defaultdict(int)\n", 123 | "for user_id, rating_list in reviews.items():\n", 124 | " for item, timestamp in rating_list:\n", 125 | " item_freq[item] += 1\n", 126 | " \n", 127 | "item_freq = dict(sorted(item_freq.items()))\n", 128 | "print(f\"Total of {sum(item_freq.values())} reviews\")\n", 129 | "\n", 130 | "# remove \"unused\" items\n", 131 | "new_items = {}\n", 132 | "new_item_freq = {}\n", 133 | "new_original_item_freq = {}\n", 134 | "for asin in tqdm(updated_items):\n", 135 | " new_items[asin] = items[asin]\n", 136 | " new_item_freq[asin] = item_freq[asin]\n", 137 | " new_original_item_freq[asin] = original_item_freq[asin]\n", 138 | "print(f\"Removed {len(items) - len(new_items)} items that are not been reviewd\")\n", 139 | "item_freq = new_item_freq\n", 140 | "items = new_items\n", 141 | "original_item_freq = new_original_item_freq\n", 142 | "\n", 143 | "\n", 144 | "print()\n", 145 | "print(f\"Items Reviews Users\")\n", 146 | "print(f\"{len(items):<4} {sum(len(v) for v in reviews.values()):<7} {len(reviews):<5}\")\n", 147 | "\n", 148 | "# fix user id\n", 149 | "user_id_mapping = dict()\n", 150 | "i = 0\n", 151 | "for original_user_id in reviews:\n", 152 | " user_id_mapping[original_user_id] = i\n", 153 | " i += 1\n", 154 | "\n", 155 | "# fix items ids\n", 156 | "item_id_mapping = dict()\n", 157 | "i = 0\n", 158 | "for asin in items:\n", 159 | " item_id_mapping[asin] = i\n", 160 | " i += 1\n", 161 | "\n", 162 | "train_item_freq = {k: 0 for k in item_freq.keys()}\n", 163 | "val_item_freq = {k: 0 for k in item_freq.keys()}\n", 164 | "test_item_freq = {k: 0 for k in item_freq.keys()}\n", 165 | "for user_id, rating_list in reviews.items():\n", 166 | " sorted_list = list(map(lambda t: t[0], sorted(rating_list, key=lambda 
t: t[1])))\n", 167 | " if len(sorted_list) < 3:\n", 168 | " train_list = sorted_list\n", 169 | " else:\n", 170 | " train_list = sorted_list[1:-2]\n", 171 | " val_item_freq[sorted_list[-2]] += 1\n", 172 | " test_item_freq[sorted_list[-1]] += 1 \n", 173 | " for asin in train_list:\n", 174 | " train_item_freq[asin] += 1\n", 175 | " \n", 176 | "\n", 177 | "with open(out_path, \"w\") as f:\n", 178 | " for user_id, rating_list in reviews.items():\n", 179 | " sorted_list = sorted(rating_list, key=lambda t: t[1])\n", 180 | " for asin, timestamp in sorted_list:\n", 181 | " f.write(f\"{user_id_mapping[user_id] + 1} {item_id_mapping[asin] + 1}\\n\") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding\n", 182 | "\n", 183 | "with open(id_to_title_map_path, \"w\") as f:\n", 184 | " for asin, title in items.items():\n", 185 | " f.write(f'{item_id_mapping[asin]} \"{title}\"\\n')\n", 186 | "\n", 187 | "with open(id_to_asin_map_path, \"w\") as f:\n", 188 | " for asin, item_id in item_id_mapping.items():\n", 189 | " f.write(f'{item_id} {asin}\\n')\n", 190 | "\n", 191 | "with open(train_item_freq_path, \"w\") as f:\n", 192 | " for asin, count in train_item_freq.items():\n", 193 | " f.write(f'{item_id_mapping[asin]} {count}\\n')" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "py39", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.9.13" 214 | }, 215 | "orig_nbformat": 4 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /data_preprocessing/HomeKitchen/preprocessing_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import os\n", 11 | "import wget\n", 12 | "import gzip\n", 13 | "from tqdm.notebook import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "base_url = \"http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/\"\n", 23 | "reviews_path = \"reviews_Home_and_Kitchen.json.gz\"\n", 24 | "meta_path = \"meta_Home_and_Kitchen.json.gz\"\n", 25 | "\n", 26 | "# download raw dataset\n", 27 | "if not os.path.exists(reviews_path):\n", 28 | " wget.download(base_url + reviews_path)\n", 29 | "if not os.path.exists(meta_path):\n", 30 | " wget.download(base_url + meta_path)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ITEM_FREQ_MIN = 5\n", 40 | "REVIEWS_REMOVE_LESS_THAN = 5\n", 41 | "\n", 42 | "out_path = \"HomeKitchen.txt\"\n", 43 | "id_to_title_map_path = \"HomeKitchen-titles.txt\"\n", 44 | "id_to_asin_map_path = \"HomeKitchen-id_to_asin.txt\"\n", 45 | "train_item_freq_path = \"HomeKitchen-train_item_freq.txt\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# load items data\n", 55 | "items = dict()\n", 56 | "skipped = 0\n", 57 
| "with gzip.open(meta_path, \"r\") as f:\n", 58 | " for line in tqdm(f):\n", 59 | " json_obj = eval(line)\n", 60 | " asin = json_obj['asin']\n", 61 | " if 'title' in json_obj:\n", 62 | " title = json_obj['title'].replace(\"\\\"\", \"'\")\n", 63 | " title = title.replace(\"\\n\", \" \").replace('"', '\\\\\"').replace('&', '&').replace('®', '').replace('™', '').replace('é', 'e').replace('°', '').replace('<', '<').replace('>', '>').replace(' ', ' ').replace('&frac', '/')\n", 64 | " if len(title) >= 2:\n", 65 | " items[asin] = title\n", 66 | " else:\n", 67 | " skipped +=1\n", 68 | " else:\n", 69 | " skipped += 1\n", 70 | "\n", 71 | "print(f\"Found {len(items)} items\")\n", 72 | "print(f\"Skipped {skipped} items without title\")\n", 73 | "\n", 74 | "# load reviews data\n", 75 | "reviews = defaultdict(list)\n", 76 | "item_freq = defaultdict(int)\n", 77 | "skipped = 0\n", 78 | "with gzip.open(reviews_path, \"r\") as f:\n", 79 | " for line in tqdm(f):\n", 80 | " json_obj = eval(line)\n", 81 | " user_id = json_obj['reviewerID']\n", 82 | " asin = json_obj['asin']\n", 83 | " timestemp = json_obj['unixReviewTime']\n", 84 | " if asin in items:\n", 85 | " reviews[user_id].append((asin, int(timestemp)))\n", 86 | " item_freq[asin] += 1\n", 87 | " else:\n", 88 | " skipped += 1\n", 89 | " # print(f\"skipped {asin}\")\n", 90 | "\n", 91 | "print(f\"Found {len(reviews)} users\")\n", 92 | "print(f\"Found {sum(item_freq.values())} reviews\")\n", 93 | "print(f\"Skipepd {skipped} item reviews without metadata\")\n", 94 | "\n", 95 | "item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}\n", 96 | "\n", 97 | "item_freq = dict(sorted(item_freq.items()))\n", 98 | "\n", 99 | "# remove user with less than K reviews\n", 100 | "removed_users_less_than = 0\n", 101 | "removed_users_item_less_than = 0\n", 102 | "removed_items = 0\n", 103 | "updated_items = set()\n", 104 | "for user_id in list(reviews.keys()):\n", 105 | " if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:\n", 106 | " del reviews[user_id]\n", 107 | " removed_users_less_than += 1\n", 108 | " else:\n", 109 | " len_before = len(reviews[user_id])\n", 110 | " reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]\n", 111 | " updated_items.update([t[0] for t in reviews[user_id]])\n", 112 | " removed_items += len_before - len(reviews[user_id])\n", 113 | " if len(reviews[user_id]) <= 0:\n", 114 | " del reviews[user_id]\n", 115 | " removed_users_item_less_than += 1\n", 116 | "print(f\"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total\")\n", 117 | "print(f\"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions\")\n", 118 | "print(f\"Removed {removed_users_item_less_than} users with only item count less than {REVIEWS_REMOVE_LESS_THAN}\")\n", 119 | "\n", 120 | "# calculate item frequencey again \n", 121 | "original_item_freq = item_freq\n", 122 | "item_freq = defaultdict(int)\n", 123 | "for user_id, rating_list in reviews.items():\n", 124 | " for item, timestamp in rating_list:\n", 125 | " item_freq[item] += 1\n", 126 | " \n", 127 | "item_freq = dict(sorted(item_freq.items()))\n", 128 | "print(f\"Total of {sum(item_freq.values())} reviews\")\n", 129 | "\n", 130 | "# remove \"unused\" items\n", 131 | "new_items = {}\n", 132 | "new_item_freq = {}\n", 133 | "new_original_item_freq = {}\n", 134 | "for asin in tqdm(updated_items):\n", 135 | " new_items[asin] = items[asin]\n", 136 | " new_item_freq[asin] = item_freq[asin]\n", 137 | " new_original_item_freq[asin] 
= original_item_freq[asin]\n", 138 | "print(f\"Removed {len(items) - len(new_items)} items that are not been reviewd\")\n", 139 | "item_freq = new_item_freq\n", 140 | "items = new_items\n", 141 | "original_item_freq = new_original_item_freq\n", 142 | "\n", 143 | "\n", 144 | "print()\n", 145 | "print(f\"Items Reviews Users\")\n", 146 | "print(f\"{len(items):<4} {sum(len(v) for v in reviews.values()):<7} {len(reviews):<5}\")\n", 147 | "\n", 148 | "# fix user id\n", 149 | "user_id_mapping = dict()\n", 150 | "i = 0\n", 151 | "for original_user_id in reviews:\n", 152 | " user_id_mapping[original_user_id] = i\n", 153 | " i += 1\n", 154 | "\n", 155 | "# fix items ids\n", 156 | "item_id_mapping = dict()\n", 157 | "i = 0\n", 158 | "for asin in items:\n", 159 | " item_id_mapping[asin] = i\n", 160 | " i += 1\n", 161 | "\n", 162 | "train_item_freq = {k: 0 for k in item_freq.keys()}\n", 163 | "val_item_freq = {k: 0 for k in item_freq.keys()}\n", 164 | "test_item_freq = {k: 0 for k in item_freq.keys()}\n", 165 | "for user_id, rating_list in reviews.items():\n", 166 | " sorted_list = list(map(lambda t: t[0], sorted(rating_list, key=lambda t: t[1])))\n", 167 | " if len(sorted_list) < 3:\n", 168 | " train_list = sorted_list\n", 169 | " else:\n", 170 | " train_list = sorted_list[1:-2]\n", 171 | " val_item_freq[sorted_list[-2]] += 1\n", 172 | " test_item_freq[sorted_list[-1]] += 1 \n", 173 | " for asin in train_list:\n", 174 | " train_item_freq[asin] += 1\n", 175 | " \n", 176 | "\n", 177 | "with open(out_path, \"w\") as f:\n", 178 | " for user_id, rating_list in reviews.items():\n", 179 | " sorted_list = sorted(rating_list, key=lambda t: t[1])\n", 180 | " for asin, timestamp in sorted_list:\n", 181 | " f.write(f\"{user_id_mapping[user_id] + 1} {item_id_mapping[asin] + 1}\\n\") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding\n", 182 | "\n", 183 | "with open(id_to_title_map_path, \"w\") as f:\n", 184 | " for asin, title in items.items():\n", 185 | " f.write(f'{item_id_mapping[asin]} \"{title}\"\\n')\n", 186 | "\n", 187 | "with open(id_to_asin_map_path, \"w\") as f:\n", 188 | " for asin, item_id in item_id_mapping.items():\n", 189 | " f.write(f'{item_id} {asin}\\n')\n", 190 | "\n", 191 | "with open(train_item_freq_path, \"w\") as f:\n", 192 | " for asin, count in train_item_freq.items():\n", 193 | " f.write(f'{item_id_mapping[asin]} {count}\\n')" 194 | ] 195 | } 196 | ], 197 | "metadata": { 198 | "kernelspec": { 199 | "display_name": "py39", 200 | "language": "python", 201 | "name": "python3" 202 | }, 203 | "language_info": { 204 | "codemirror_mode": { 205 | "name": "ipython", 206 | "version": 3 207 | }, 208 | "file_extension": ".py", 209 | "mimetype": "text/x-python", 210 | "name": "python", 211 | "nbconvert_exporter": "python", 212 | "pygments_lexer": "ipython3", 213 | "version": "3.10.12" 214 | }, 215 | "orig_nbformat": 4 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 2 219 | } 220 | -------------------------------------------------------------------------------- /data_preprocessing/Tools/preprocessing_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from collections import defaultdict\n", 10 | "import os\n", 11 | "import wget\n", 12 | "import gzip\n", 13 | "from tqdm.notebook import tqdm" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": 
{}, 20 | "outputs": [], 21 | "source": [ 22 | "base_url = \"http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/\"\n", 23 | "reviews_path = \"reviews_Tools_and_Home_Improvement.json.gz\"\n", 24 | "meta_path = \"meta_Tools_and_Home_Improvement.json.gz\"\n", 25 | "\n", 26 | "# download raw dataset\n", 27 | "if not os.path.exists(reviews_path):\n", 28 | " wget.download(base_url + reviews_path)\n", 29 | "if not os.path.exists(meta_path):\n", 30 | " wget.download(base_url + meta_path)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ITEM_FREQ_MIN = 5\n", 40 | "REVIEWS_REMOVE_LESS_THAN = 5\n", 41 | "\n", 42 | "out_path = \"Tools.txt\"\n", 43 | "id_to_title_map_path = \"Tools-titles.txt\"\n", 44 | "id_to_asin_map_path = \"Tools-id_to_asin.txt\"\n", 45 | "train_item_freq_path = \"Tools-train_item_freq.txt\"" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# load items data\n", 55 | "items = dict()\n", 56 | "skipped = 0\n", 57 | "with gzip.open(meta_path, \"r\") as f:\n", 58 | " for line in tqdm(f):\n", 59 | " json_obj = eval(line)\n", 60 | " asin = json_obj['asin']\n", 61 | " if 'title' in json_obj:\n", 62 | " title = json_obj['title'].replace(\"\\\"\", \"'\")\n", 63 | " title = title.replace(\"\\n\", \" \").replace('"', '\\\\\"').replace('&', '&').replace('®', '').replace('™', '').replace('é', 'e').replace('°', '').replace('<', '<').replace('>', '>').replace(' ', ' ').replace('&frac', '/')\n", 64 | " if len(title) >= 2:\n", 65 | " items[asin] = title\n", 66 | " else:\n", 67 | " skipped +=1\n", 68 | " else:\n", 69 | " skipped += 1\n", 70 | "\n", 71 | "print(f\"Found {len(items)} items\")\n", 72 | "print(f\"Skipped {skipped} items without title\")\n", 73 | "\n", 74 | "# load reviews data\n", 75 | "reviews = defaultdict(list)\n", 76 | "item_freq = defaultdict(int)\n", 77 | "skipped = 0\n", 78 | "with gzip.open(reviews_path, \"r\") as f:\n", 79 | " for line in tqdm(f):\n", 80 | " json_obj = eval(line)\n", 81 | " user_id = json_obj['reviewerID']\n", 82 | " asin = json_obj['asin']\n", 83 | " timestemp = json_obj['unixReviewTime']\n", 84 | " if asin in items:\n", 85 | " reviews[user_id].append((asin, int(timestemp)))\n", 86 | " item_freq[asin] += 1\n", 87 | " else:\n", 88 | " skipped += 1\n", 89 | " # print(f\"skipped {asin}\")\n", 90 | "\n", 91 | "print(f\"Found {len(reviews)} users\")\n", 92 | "print(f\"Found {sum(item_freq.values())} reviews\")\n", 93 | "print(f\"Skipepd {skipped} item reviews without metadata\")\n", 94 | "\n", 95 | "\n", 96 | "item_freq = {k: v for k, v in item_freq.items() if v >= ITEM_FREQ_MIN}\n", 97 | "\n", 98 | "item_freq = dict(sorted(item_freq.items()))\n", 99 | "\n", 100 | "# remove user with less than K reviews\n", 101 | "removed_users_less_than = 0\n", 102 | "removed_users_item_less_than = 0\n", 103 | "removed_items = 0\n", 104 | "updated_items = set()\n", 105 | "for user_id in list(reviews.keys()):\n", 106 | " if len(reviews[user_id]) < REVIEWS_REMOVE_LESS_THAN:\n", 107 | " del reviews[user_id]\n", 108 | " removed_users_less_than += 1\n", 109 | " else:\n", 110 | " len_before = len(reviews[user_id])\n", 111 | " reviews[user_id] = [item for item in reviews[user_id] if item[0] in item_freq]\n", 112 | " updated_items.update([t[0] for t in reviews[user_id]])\n", 113 | " removed_items += len_before - len(reviews[user_id])\n", 114 | " if len(reviews[user_id]) <= 0:\n", 115 | " del 
reviews[user_id]\n", 116 | " removed_users_item_less_than += 1\n", 117 | "print(f\"Removed {removed_items} reviews of items that appear less than {ITEM_FREQ_MIN} in total\")\n", 118 | "print(f\"Removed {removed_users_less_than} users with less than {REVIEWS_REMOVE_LESS_THAN} actions\")\n", 119 | "print(f\"Removed {removed_users_item_less_than} users with only item count less than {REVIEWS_REMOVE_LESS_THAN}\")\n", 120 | "\n", 121 | "# calculate item frequencey again \n", 122 | "original_item_freq = item_freq\n", 123 | "item_freq = defaultdict(int)\n", 124 | "for user_id, rating_list in reviews.items():\n", 125 | " for item, timestamp in rating_list:\n", 126 | " item_freq[item] += 1\n", 127 | " \n", 128 | "item_freq = dict(sorted(item_freq.items()))\n", 129 | "print(f\"Total of {sum(item_freq.values())} reviews\")\n", 130 | "\n", 131 | "# remove \"unused\" items\n", 132 | "new_items = {}\n", 133 | "new_item_freq = {}\n", 134 | "new_original_item_freq = {}\n", 135 | "for asin in tqdm(updated_items):\n", 136 | " new_items[asin] = items[asin]\n", 137 | " new_item_freq[asin] = item_freq[asin]\n", 138 | " new_original_item_freq[asin] = original_item_freq[asin]\n", 139 | "print(f\"Removed {len(items) - len(new_items)} items that are not been reviewd\")\n", 140 | "item_freq = new_item_freq\n", 141 | "items = new_items\n", 142 | "original_item_freq = new_original_item_freq\n", 143 | "\n", 144 | "\n", 145 | "print()\n", 146 | "print(f\"Items Reviews Users\")\n", 147 | "print(f\"{len(items):<4} {sum(len(v) for v in reviews.values()):<7} {len(reviews):<5}\")\n", 148 | "\n", 149 | "# fix user id\n", 150 | "user_id_mapping = dict()\n", 151 | "i = 0\n", 152 | "for original_user_id in reviews:\n", 153 | " user_id_mapping[original_user_id] = i\n", 154 | " i += 1\n", 155 | "\n", 156 | "# fix items ids\n", 157 | "item_id_mapping = dict()\n", 158 | "i = 0\n", 159 | "for asin in items:\n", 160 | " item_id_mapping[asin] = i\n", 161 | " i += 1\n", 162 | "\n", 163 | "train_item_freq = {k: 0 for k in item_freq.keys()}\n", 164 | "val_item_freq = {k: 0 for k in item_freq.keys()}\n", 165 | "test_item_freq = {k: 0 for k in item_freq.keys()}\n", 166 | "for user_id, rating_list in reviews.items():\n", 167 | " sorted_list = list(map(lambda t: t[0], sorted(rating_list, key=lambda t: t[1])))\n", 168 | " if len(sorted_list) < 3:\n", 169 | " train_list = sorted_list\n", 170 | " else:\n", 171 | " train_list = sorted_list[1:-2]\n", 172 | " val_item_freq[sorted_list[-2]] += 1\n", 173 | " test_item_freq[sorted_list[-1]] += 1 \n", 174 | " for asin in train_list:\n", 175 | " train_item_freq[asin] += 1\n", 176 | " \n", 177 | "\n", 178 | "with open(out_path, \"w\") as f:\n", 179 | " for user_id, rating_list in reviews.items():\n", 180 | " sorted_list = sorted(rating_list, key=lambda t: t[1])\n", 181 | " for asin, timestamp in sorted_list:\n", 182 | " f.write(f\"{user_id_mapping[user_id] + 1} {item_id_mapping[asin] + 1}\\n\") # start user id from 1 to match original SASRec paper,reserve the 0 index for padding\n", 183 | "\n", 184 | "with open(id_to_title_map_path, \"w\") as f:\n", 185 | " for asin, title in items.items():\n", 186 | " f.write(f'{item_id_mapping[asin]} \"{title}\"\\n')\n", 187 | "\n", 188 | "with open(id_to_asin_map_path, \"w\") as f:\n", 189 | " for asin, item_id in item_id_mapping.items():\n", 190 | " f.write(f'{item_id} {asin}\\n')\n", 191 | "\n", 192 | "with open(train_item_freq_path, \"w\") as f:\n", 193 | " for asin, count in train_item_freq.items():\n", 194 | " f.write(f'{item_id_mapping[asin]} {count}\\n')" 195 
| ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "py39", 201 | "language": "python", 202 | "name": "python3" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.10.12" 215 | }, 216 | "orig_nbformat": 4 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 2 220 | } 221 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /data_preprocessing/calculate_similarity_scores.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import torch\n", 11 | "from tqdm.notebook import tqdm\n", 12 | "import numpy as np\n", 13 | "from sentence_transformers import SentenceTransformer" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "EMBEDDING_MODEL = 'thenlper/gte-large'\n", 23 | "DELIMITER=\" \"\n", 24 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 25 | "batch_size = 128\n", 26 | "K = 1000" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Datasets\n", 34 | "\n", 35 | "Uncomment the dataset you want work on." 
36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# # ml-1m\n", 45 | "# def ml_preprocessing(title):\n", 46 | "# title = \" \".join(title.split(\" \")[:-1]).strip()\n", 47 | "# if title.endswith(\", The\"):\n", 48 | "# title = \"The \" + title[:-5] \n", 49 | "# if title.endswith(\", A\"):\n", 50 | "# title = \"A \" + title[:-3] \n", 51 | "# return title\n", 52 | "\n", 53 | "# data_path = \"ML-1M/ml-1m.txt\"\n", 54 | "# titles_path = \"ML-1M/ml-1m-titles.txt\"\n", 55 | "# title_freq_path = \"ML-1M/ml-1m-train_item_freq.txt\"\n", 56 | "# similarity_indices_out = f\"ML-1M/ml-1m-similarity-indices-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 57 | "# similarity_values_out = f\"ML-1M/ml-1m-similarity-values-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 58 | "# embeddings_out = f\"ML-1M/ml-1m-embeddings-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 59 | "# timestamp_path = \"ML-1M/ml-1m_timestamp.txt\"\n", 60 | "# preprocessing_title = ml_preprocessing" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "# Beauty\n", 70 | "data_path = \"Beauty/Beauty.txt\"\n", 71 | "titles_path = \"Beauty/Beauty-titles.txt\"\n", 72 | "title_freq_path = \"Beauty/Beauty-train_item_freq.txt\"\n", 73 | "similarity_indices_out = f\"Beauty/Beauty-similarity-indices-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 74 | "similarity_values_out = f\"Beauty/Beauty-similarity-values-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 75 | "timestamp_path = f\"Beauty/Beauty-{EMBEDDING_MODEL.replace('/','_')}_timestamp.txt\"\n", 76 | "preprocessing_title = lambda t: t" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# # Tools\n", 86 | "# data_path = \"Tools/Tools.txt\"\n", 87 | "# titles_path = \"Tools/Tools-titles.txt\"\n", 88 | "# title_freq_path = \"Tools/Tools-train_item_freq.txt\"\n", 89 | "# similarity_indices_out = f\"Tools/Tools-similarity-indices-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 90 | "# similarity_values_out = f\"Tools/Tools-similarity-values-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 91 | "# timestamp_path = \"Tools/Tools_timestamp.txt\"\n", 92 | "# preprocessing_title = lambda t: t" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "# # HomeKitchen\n", 102 | "# data_path = \"HomeKitchen/HomeKitchen.txt\"\n", 103 | "# titles_path = \"HomeKitchen/HomeKitchen-titles.txt\"\n", 104 | "# title_freq_path = \"HomeKitchen/HomeKitchen-train_item_freq.txt\"\n", 105 | "# similarity_indices_out = f\"HomeKitchen/HomeKitchen-similarity-indices-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 106 | "# similarity_values_out = f\"HomeKitchen/HomeKitchen-similarity-values-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 107 | "# timestamp_path = \"HomeKitchen/HomeKitchen_timestamp.txt\"\n", 108 | "# preprocessing_title = lambda t: t" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "# # Pet Supplies\n", 118 | "# data_path = \"PetSupplies/Pet.txt\"\n", 119 | "# titles_path = \"PetSupplies/Pet-titles.txt\"\n", 120 | "# title_freq_path = \"PetSupplies/Pet-train_item_freq.txt\"\n", 121 | "# similarity_indices_out = 
f\"PetSupplies/Pet-similarity-indices-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 122 | "# similarity_values_out = f\"PetSupplies/Pet-similarity-values-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 123 | "# timestamp_path = \"PetSupplies/Pet_timestamp.txt\"\n", 124 | "# preprocessing_title = lambda t: t" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# # Steam\n", 134 | "# data_path = \"Steam/steam.txt\"\n", 135 | "# titles_path = \"Steam/steam-titles.txt\"\n", 136 | "# title_freq_path = \"Steam/steam-train_item_freq.txt\"\n", 137 | "# similarity_indices_out = f\"Steam/steam-similarity-indices-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 138 | "# similarity_values_out = f\"Steam/steam-similarity-values-{EMBEDDING_MODEL.replace('/','_')}.pt\"\n", 139 | "# timestamp_path = f\"Steam/steam_timestamp.txt\"\n", 140 | "# preprocessing_title = lambda t: t" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "## Calcualte Similarities" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "def sentence_transformer(model_name, batch_size, device):\n", 157 | " model = SentenceTransformer(model_name, device=device)\n", 158 | " def embed(sentences):\n", 159 | " embeddings = []\n", 160 | " batches = [sentences[x:x+batch_size] for x in range(0, len(sentences), batch_size)]\n", 161 | " for batch in tqdm(batches):\n", 162 | " embeddings.append(model.encode(batch, convert_to_numpy=False, convert_to_tensor=True))\n", 163 | " return torch.cat(embeddings, dim=0)\n", 164 | " return embed\n", 165 | "\n", 166 | "embedding_func = sentence_transformer(model_name=EMBEDDING_MODEL, batch_size=batch_size, device=device)" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "titles_df = pd.read_csv(titles_path, names=['id', 'title'], delimiter=DELIMITER, escapechar=\"\\\\\")\n", 176 | "titles_df" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "id_to_freq_df = pd.read_csv(title_freq_path, names=['id', 'freq'], delimiter=DELIMITER)\n", 186 | "id_to_freq_series = pd.Series(id_to_freq_df.freq.values, index=id_to_freq_df.id)\n", 187 | "id_to_freq = id_to_freq_series.to_dict()\n", 188 | "titles_df['freq'] = id_to_freq_series\n", 189 | "titles_df = titles_df[['id', 'freq', 'title']]\n", 190 | "titles_df" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": {}, 197 | "outputs": [], 198 | "source": [ 199 | "titles_df['title'] = titles_df['title'].apply(np.vectorize(preprocessing_title))\n", 200 | "titles_df" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "titles_list = titles_df['title'].tolist()\n", 210 | "titles_embeddings = embedding_func(titles_list)\n", 211 | "titles_embeddings" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "titles_embeddings.shape" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "def get_similarity_matrix(emebddings, 
eps=1e-8, top_k=None):\n", 230 | " embeddings_norm = emebddings.norm(dim=1).unsqueeze(dim=1) # (num_embeddings, 1)\n", 231 | " embeddings_normalized = emebddings / torch.max(embeddings_norm, eps * torch.ones_like(embeddings_norm))\n", 232 | " if top_k is None:\n", 233 | " similarity_values = embeddings_normalized @ embeddings_normalized.T\n", 234 | " # fix numerical percison issues - where similarity_matrix[i,i] < similarity_matrix[i, k != i]\n", 235 | " similarity_values += torch.diag(torch.full((similarity_values.shape[0],), 1e-7, device=device))\n", 236 | " similarity_indices = torch.arange(similarity_values.shape[0]).unsqueeze(dim=0).repeat(similarity_values.shape[0], 1)\n", 237 | "\n", 238 | " else:\n", 239 | " n_embeddings = emebddings.shape[0]\n", 240 | " chunks = n_embeddings // 1000\n", 241 | " value_list = []\n", 242 | " indices_list = []\n", 243 | " for chunk in embeddings_normalized.chunk(chunks):\n", 244 | " similarity_out = chunk @ embeddings_normalized.T \n", 245 | " values, indices = torch.topk(similarity_out, dim= -1, k=top_k, sorted=True)\n", 246 | " value_list.append(values)\n", 247 | " indices_list.append(indices)\n", 248 | " similarity_values = torch.cat(value_list, dim=0)\n", 249 | " similarity_indices = torch.cat(indices_list, dim=0)\n", 250 | "\n", 251 | " return similarity_values, similarity_indices" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "similarity_values, similarity_indices = get_similarity_matrix(titles_embeddings, top_k=K)\n", 261 | "print(similarity_indices)\n", 262 | "similarity_values" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Save all embeddings and similarities" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "torch.save(similarity_indices, similarity_indices_out)\n", 279 | "torch.save(similarity_values, similarity_values_out)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "! 
echo `date +'%I_%M_%d_%m'` > {timestamp_path}" 289 | ] 290 | } 291 | ], 292 | "metadata": { 293 | "kernelspec": { 294 | "display_name": "py39", 295 | "language": "python", 296 | "name": "python3" 297 | }, 298 | "language_info": { 299 | "codemirror_mode": { 300 | "name": "ipython", 301 | "version": 3 302 | }, 303 | "file_extension": ".py", 304 | "mimetype": "text/x-python", 305 | "name": "python", 306 | "nbconvert_exporter": "python", 307 | "pygments_lexer": "ipython3", 308 | "version": "3.9.17" 309 | }, 310 | "orig_nbformat": 4 311 | }, 312 | "nbformat": 4, 313 | "nbformat_minor": 2 314 | } 315 | -------------------------------------------------------------------------------- /SimRec/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import torch 4 | import random 5 | import numpy as np 6 | from collections import defaultdict 7 | from multiprocessing import Process, Queue 8 | from enum import auto, Enum 9 | from itertools import chain 10 | 11 | class LinearScheduleWithWarmup: 12 | def __init__(self, lambd, warmup_steps, lamb_steps): 13 | self.lambd = 0 14 | self.warmup_steps = warmup_steps 15 | self.lamb_steps = lamb_steps 16 | self.warmup_alpha = lambd / warmup_steps 17 | self.alpha = lambd / (warmup_steps - lamb_steps) 18 | self.bias = lambd * (1 - (warmup_steps / (warmup_steps - lamb_steps))) 19 | self.current_step = -1 20 | self.step() 21 | 22 | def get_lambd(self): 23 | return max(self.lambd, 0) 24 | 25 | def step(self): 26 | self.current_step += 1 27 | if self.current_step < self.warmup_steps: 28 | self.lambd = self.warmup_alpha * self.current_step 29 | else: 30 | self.lambd = self.alpha * self.current_step + self.bias 31 | 32 | class NoneSchedule: 33 | def __init__(self, lambd): 34 | self.lambd = lambd 35 | 36 | def get_lambd(self): 37 | return self.lambd 38 | 39 | def step(self): 40 | pass 41 | 42 | PAD_IDX = 0 43 | 44 | def create_similarity_distirbution(similarity_indices, similarity_values, temperature, positive_indices): 45 | num_items = similarity_indices.shape[0] 46 | num_positives = positive_indices.shape[0] 47 | # (num_positives, top_k_similar) 48 | pos_similarity_indices = torch.index_select(similarity_indices, index=positive_indices, dim=0) 49 | pos_similarity_values = torch.index_select(similarity_values, index=positive_indices, dim=0) 50 | 51 | # (num_positives, num_items) 52 | similarities = torch.full((num_positives, num_items), fill_value=-float('inf'), device=similarity_indices.device) 53 | similarities.scatter_(dim=1, index=pos_similarity_indices, src=pos_similarity_values) 54 | 55 | similarities /= temperature 56 | 57 | distribution = torch.nn.functional.softmax(similarities, dim=-1) 58 | return distribution 59 | 60 | # sampler for batch generation 61 | def random_neq(l, r, s): 62 | t = np.random.randint(l, r) 63 | while t in s: 64 | t = np.random.randint(l, r) 65 | return t 66 | 67 | 68 | def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED): 69 | def sample(): 70 | 71 | user = np.random.randint(1, usernum + 1) 72 | while len(user_train[user]) <= 1: 73 | user = np.random.randint(1, usernum + 1) 74 | 75 | seq = np.zeros([maxlen], dtype=np.int32) 76 | pos = np.zeros([maxlen], dtype=np.int32) 77 | neg = np.zeros([maxlen], dtype=np.int32) 78 | nxt = user_train[user][-1] 79 | idx = maxlen - 1 80 | 81 | ts = set(user_train[user]) 82 | for i in reversed(user_train[user][:-1]): 83 | seq[idx] = i 84 | pos[idx] = nxt 85 | if nxt != 0: 86 | neg[idx] = random_neq(1, 
itemnum + 1, ts) 87 | nxt = i 88 | idx -= 1 89 | if idx == -1: 90 | break 91 | 92 | return user, seq, pos, neg 93 | 94 | np.random.seed(SEED) 95 | while True: 96 | one_batch = [] 97 | for i in range(batch_size): 98 | one_batch.append(sample()) 99 | 100 | result_queue.put(zip(*one_batch)) 101 | 102 | 103 | class WarpSampler(object): 104 | def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1): 105 | self.result_queue = Queue(maxsize=n_workers * 10) 106 | self.processors = [] 107 | for i in range(n_workers): 108 | self.processors.append( 109 | Process(target=sample_function, args=(User, 110 | usernum, 111 | itemnum, 112 | batch_size, 113 | maxlen, 114 | self.result_queue, 115 | np.random.randint(2e9) 116 | ))) 117 | self.processors[-1].daemon = True 118 | self.processors[-1].start() 119 | 120 | def next_batch(self): 121 | return self.result_queue.get() 122 | 123 | def close(self): 124 | for p in self.processors: 125 | p.terminate() 126 | p.join() 127 | 128 | 129 | # train/val/test data generation 130 | def data_partition(fname, augmentations_fname=None): 131 | usernum = 0 132 | itemnum = 0 133 | User = defaultdict(list) 134 | user_train = {} 135 | user_valid = {} 136 | user_test = {} 137 | # assume user/item index starting from 1 138 | with open(fname, 'r') as f: 139 | for line in f: 140 | u, i = line.rstrip().split(' ') 141 | u = int(u) 142 | i = int(i) 143 | usernum = max(u, usernum) 144 | itemnum = max(i, itemnum) 145 | User[u].append(i) 146 | if augmentations_fname is not None: 147 | with open(f'data/{augmentations_fname}.txt', 'r') as f: 148 | for line in f: 149 | u, i = line.rstrip().split(' ') 150 | u = int(u) 151 | i = int(i) 152 | usernum = max(u, usernum) 153 | itemnum = max(i, itemnum) 154 | User[u].append(i) 155 | 156 | for user in User: 157 | nfeedback = len(User[user]) 158 | if nfeedback < 3: 159 | user_train[user] = User[user] 160 | user_valid[user] = [] 161 | user_test[user] = [] 162 | else: 163 | user_train[user] = User[user][:-2] 164 | user_valid[user] = [] 165 | user_valid[user].append(User[user][-2]) 166 | user_test[user] = [] 167 | user_test[user].append(User[user][-1]) 168 | return [user_train, user_valid, user_test, usernum, itemnum] 169 | 170 | # evaluate on test set 171 | def evaluate_test(model, dataset, args): 172 | 173 | [train, valid, test, usernum, itemnum] = copy.deepcopy(dataset) 174 | 175 | NDCG = 0.0 176 | HT = 0.0 177 | valid_user = 0.0 178 | id_hr = defaultdict(list) 179 | id_ndcg = defaultdict(list) 180 | if usernum>10000: 181 | users = random.sample(range(1, usernum + 1), 10000) 182 | else: 183 | users = range(1, usernum + 1) 184 | for u in users: 185 | 186 | if len(train[u]) < 1 or len(test[u]) < 1: continue 187 | 188 | seq = np.zeros([args.maxlen], dtype=np.int32) 189 | idx = args.maxlen - 1 190 | seq[idx] = valid[u][0] 191 | idx -= 1 192 | for i in reversed(train[u]): 193 | seq[idx] = i 194 | idx -= 1 195 | if idx == -1: break 196 | rated = set(train[u]) 197 | rated.add(0) 198 | item_idx = [test[u][0]] 199 | for _ in range(100): 200 | t = np.random.randint(1, itemnum + 1) 201 | while t in rated: t = np.random.randint(1, itemnum + 1) 202 | item_idx.append(t) 203 | 204 | predictions = -model.predict(*[np.array(l) for l in [[u], [seq], item_idx]]) 205 | predictions = predictions[0] # - for 1st argsort DESC 206 | 207 | rank = predictions.argsort().argsort()[0].item() 208 | 209 | valid_user += 1 210 | if rank < 10: 211 | ndcg = 1 / np.log2(rank + 2) 212 | NDCG += ndcg 213 | HT += 1 214 | id_hr[item_idx[0]].append(1) 215 | 
id_ndcg[item_idx[0]].append(ndcg) 216 | else: 217 | id_hr[item_idx[0]].append(0) 218 | id_ndcg[item_idx[0]].append(0) 219 | if valid_user % 100 == 0: 220 | print('.', end="") 221 | sys.stdout.flush() 222 | return (NDCG / valid_user, HT / valid_user), id_hr, id_ndcg 223 | 224 | 225 | # evaluate on val set 226 | def evaluate_valid(model, dataset, args): 227 | [train, valid, test, usernum, itemnum] = copy.deepcopy(dataset) 228 | 229 | NDCG = 0.0 230 | HT = 0.0 231 | valid_user = 0.0 232 | id_hr = defaultdict(list) 233 | id_ndcg = defaultdict(list) 234 | if usernum>10000: 235 | users = random.sample(range(1, usernum + 1), 10000) 236 | else: 237 | users = range(1, usernum + 1) 238 | for u in users: 239 | if len(train[u]) < 1 or len(valid[u]) < 1: continue 240 | 241 | seq = np.zeros([args.maxlen], dtype=np.int32) 242 | idx = args.maxlen - 1 243 | for i in reversed(train[u]): 244 | seq[idx] = i 245 | idx -= 1 246 | if idx == -1: break 247 | 248 | rated = set(train[u]) 249 | rated.add(0) 250 | item_idx = [valid[u][0]] 251 | for _ in range(100): 252 | t = np.random.randint(1, itemnum + 1) 253 | while t in rated: t = np.random.randint(1, itemnum + 1) 254 | item_idx.append(t) 255 | 256 | predictions = -model.predict(*[np.array(l) for l in [[u], [seq], item_idx]]) 257 | predictions = predictions[0] 258 | 259 | rank = predictions.argsort().argsort()[0].item() 260 | 261 | valid_user += 1 262 | if rank < 10: 263 | ndcg = 1 / np.log2(rank + 2) 264 | NDCG += ndcg 265 | HT += 1 266 | id_hr[item_idx[0]].append(1) 267 | id_ndcg[item_idx[0]].append(ndcg) 268 | else: 269 | id_hr[item_idx[0]].append(0) 270 | id_ndcg[item_idx[0]].append(0) 271 | if valid_user % 100 == 0: 272 | print('.', end="") 273 | sys.stdout.flush() 274 | return (NDCG / valid_user, HT / valid_user), id_hr, id_ndcg 275 | 276 | # evaluate on train set 277 | def evaluate_train(model, dataset, args): 278 | [train, valid, test, usernum, itemnum] = copy.deepcopy(dataset) 279 | 280 | NDCG = 0.0 281 | HT = 0.0 282 | valid_user = 0.0 283 | id_hr = defaultdict(list) 284 | id_ndcg = defaultdict(list) 285 | if usernum>10000: 286 | users = random.sample(range(1, usernum + 1), 10000) 287 | else: 288 | users = range(1, usernum + 1) 289 | for u in users: 290 | if len(train[u]) < 1: continue 291 | 292 | seq = np.zeros([args.maxlen], dtype=np.int32) 293 | idx = args.maxlen - 1 294 | for i in reversed(train[u][:-1]): 295 | seq[idx] = i 296 | idx -= 1 297 | if idx == -1: break 298 | 299 | rated = set(train[u][:-1]) 300 | rated.add(0) 301 | item_idx = [train[u][-1]] 302 | for _ in range(100): 303 | t = np.random.randint(1, itemnum + 1) 304 | while t in rated: t = np.random.randint(1, itemnum + 1) 305 | item_idx.append(t) 306 | 307 | predictions = -model.predict(*[np.array(l) for l in [[u], [seq], item_idx]]) 308 | predictions = predictions[0] 309 | 310 | rank = predictions.argsort().argsort()[0].item() 311 | 312 | valid_user += 1 313 | if rank < 10: 314 | ndcg = 1 / np.log2(rank + 2) 315 | NDCG += ndcg 316 | HT += 1 317 | id_hr[item_idx[0]].append(1) 318 | id_ndcg[item_idx[0]].append(ndcg) 319 | else: 320 | id_hr[item_idx[0]].append(0) 321 | id_ndcg[item_idx[0]].append(0) 322 | if valid_user % 100 == 0: 323 | print('.', end="") 324 | sys.stdout.flush() 325 | return (NDCG / valid_user, HT / valid_user), id_hr, id_ndcg -------------------------------------------------------------------------------- /SimRec/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import json 5 | 
import argparse 6 | from functools import partial 7 | import pandas as pd 8 | from model import SASRec 9 | from utils import * 10 | 11 | def str2bool(s): 12 |     if s not in {'false', 'true', 'False', 'True'}: 13 |         raise ValueError('Not a valid boolean string') 14 |     return s == 'true' or s == 'True' 15 | 16 | def parse_args(): 17 |     parser = argparse.ArgumentParser() 18 |     parser.add_argument('--dataset', default="../data_preprocessing/Beauty/Beauty.txt") 19 |     parser.add_argument('--item_frequency', default="../data_preprocessing/Beauty/Beauty-item_freq.txt") 20 | 21 |     parser.add_argument('--similarity_indices', default="../data_preprocessing/Beauty/Beauty-similarity-indices-thenlper_gte-large.pt", type=str) 22 |     parser.add_argument('--similarity_values', default="../data_preprocessing/Beauty/Beauty-similarity-values-thenlper_gte-large.pt", type=str) 23 |     parser.add_argument('--similarity_threshold', default=0.9, type=float, help="zero similarities that are lower than the threshold (cosine similarity is between [-1,1])") 24 | 25 |     parser.add_argument('--temperature', default=1, type=float, help="softmax temperature for training") 26 |     parser.add_argument('--lambd', default=0.5, type=float, help="control the weight of the 'distillation' loss") 27 |     parser.add_argument('--lambd_scheduling', default="LINEAR", choices=["LINEAR", "NONE"]) 28 |     parser.add_argument('--lambd_warmup_steps', default=1000, type=int, help="control the number of warmup steps for the lambda scheduler") 29 |     parser.add_argument('--lambd_steps', default=70000, type=int, help="control the number of lambda steps for the lambda scheduler") 30 | 31 |     parser.add_argument('--train_dir', default='train_beauty') 32 | 33 |     parser.add_argument('--batch_size', default=128, type=int) 34 |     parser.add_argument('--lr', default=0.001, type=float) 35 |     parser.add_argument('--maxlen', default=50, type=int) 36 |     parser.add_argument('--hidden_units', default=50, type=int) 37 |     parser.add_argument('--num_blocks', default=2, type=int) 38 |     parser.add_argument('--num_epochs', default=201, type=int) 39 |     parser.add_argument('--num_heads', default=1, type=int) 40 |     parser.add_argument('--dropout_rate', default=0.5, type=float) 41 |     parser.add_argument('--l2_emb', default=0.0, type=float) 42 | 43 |     parser.add_argument('--device', default='cuda', type=str) 44 |     parser.add_argument('--inference_only', default='false', type=str2bool) 45 |     parser.add_argument('--state_dict_path', type=str) 46 |     args = parser.parse_args() 47 |     return args 48 | 49 | 50 | def main(args): 51 |     print(args) 52 | 53 |     item_freq_df = pd.read_csv(args.item_frequency, delimiter=' ', names=['id', 'freq']) 54 |     item_freq_df['id'] += 1  # id 0 is reserved for PAD 55 |     item_freq = pd.Series(item_freq_df.freq.values, index=item_freq_df.id).to_dict() 56 | 57 |     if not os.path.isdir(args.train_dir): 58 |         os.makedirs(args.train_dir) 59 |     with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f: 60 |         f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])])) 61 | 62 |     # global dataset 63 |     dataset = data_partition(args.dataset) 64 | 65 | 66 |     if args.similarity_indices is None or args.similarity_values is None: 67 |         raise Exception("args.similarity_indices or args.similarity_values are None") 68 |     similarity_indices = torch.load(args.similarity_indices, map_location=args.device) 69 |     similarity_values = torch.load(args.similarity_values, map_location=args.device) 70 |     if args.similarity_threshold < 1: 71 |         similarity_values[similarity_values <=
args.similarity_threshold] = -float('inf') 72 |     else: 73 |         # make the self similarity maximal 74 |         similarity_indices = torch.arange(similarity_indices.shape[0], device=args.device).reshape(-1, 1) 75 |         similarity_values = torch.ones_like(similarity_indices) 76 | 77 |     # handle padding index (0) by increasing the index values 78 |     similarity_indices += 1 79 | 80 |     similarity_indices = torch.concat([torch.arange((similarity_indices.shape[1]), device=args.device).unsqueeze(dim=0), similarity_indices], dim=0) 81 |     similarity_values = torch.concat([torch.full((1, similarity_values.shape[1]), fill_value=-float('inf'), device=args.device), similarity_values], dim=0) 82 | 83 | 84 |     [user_train, user_valid, user_test, usernum, itemnum] = dataset 85 |     num_batch = len(user_train) // args.batch_size # tail? + ((len(user_train) % args.batch_size) != 0) 86 |     training_steps = num_batch * args.num_epochs 87 |     cc = 0.0 88 |     for u in user_train: 89 |         cc += len(user_train[u]) 90 |     print(f'average sequence length: {cc / len(user_train):.2f}') 91 |     print(f"{training_steps} training steps in total") 92 | 93 |     lambda_scheduler = NoneSchedule(args.lambd) 94 |     if args.lambd_scheduling == 'LINEAR': 95 |         lambda_scheduler = LinearScheduleWithWarmup(args.lambd, args.lambd_warmup_steps, min(training_steps, args.lambd_steps)) 96 | 97 |     f = open(os.path.join(args.train_dir, 'log.txt'), 'w') 98 | 99 |     sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3) 100 |     model = SimRec(usernum, itemnum, args).to(args.device) 101 | 102 |     model_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 103 |     print(f"{model_total_params} model parameters") 104 | 105 |     model.train() 106 | 107 |     epoch_start_idx = 1 108 |     if args.state_dict_path is not None: 109 |         try: 110 |             model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device))) 111 |             tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:] 112 |             epoch_start_idx = int(tail[:tail.find('.')]) + 1 113 |         except: # in case your pytorch version is not 1.6 etc., pls debug by pdb if load weights failed 114 |             print('failed loading state_dicts, pls check file path: ', end="") 115 |             print(args.state_dict_path) 116 |             print('pdb enabled for your quick check, pls type exit() if you do not need it') 117 |             import pdb; pdb.set_trace() 118 | 119 | 120 |     if args.inference_only: 121 |         model.eval() 122 |         test_eval_results = [evaluate_test(model, dataset, args) for _ in range(5)] 123 | 124 |         test_id_hr = test_eval_results[-1][-2] 125 |         test_id_ndcg = test_eval_results[-1][-1] 126 | 127 |         test_hr_list = [e[0][1] for e in test_eval_results] 128 |         test_ndcg_list = [e[0][0] for e in test_eval_results] 129 | 130 |         test_hr = np.array(test_hr_list).mean() 131 |         test_ndcg = np.array(test_ndcg_list).mean() 132 |         print(f'test (NDCG@10: {test_ndcg:.4f}, HR@10: {test_hr:.4f})') 133 | 134 | 135 |     bce_criterion = torch.nn.BCEWithLogitsLoss() 136 |     cross_entropy_criterion = torch.nn.CrossEntropyLoss() 137 |     adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98)) 138 | 139 |     T = 0.0 140 |     t0 = time.time() 141 |     fname = f'SimRec.epoch={args.num_epochs}.pth' 142 |     best_val_hr = -float('inf') 143 |     best_results = {} 144 |     global_step = 0 145 |     for epoch in range(epoch_start_idx, args.num_epochs + 1): 146 |         if args.inference_only: 147 |             break 148 |         for step in range(num_batch): 149 |             lambd = lambda_scheduler.get_lambd() 150 |             u, seq, pos, neg =
sampler.next_batch() 151 | u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg) 152 | 153 | pos_logits, neg_logits, logits = model(u, seq, pos, neg) 154 | pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device) 155 | 156 | adam_optimizer.zero_grad() 157 | indices = np.where(pos != PAD_IDX) 158 | loss = bce_criterion(pos_logits[indices], pos_labels[indices]) 159 | loss += bce_criterion(neg_logits[indices], neg_labels[indices]) 160 | 161 | pos = torch.tensor(pos, device=args.device, dtype=torch.long).view(-1) # (batch_size x max_len) 162 | pad_indices = pos == PAD_IDX 163 | pos = pos[~pad_indices] 164 | logits_flat = logits.view(-1, itemnum + 1) # (batch_size x max_len ,num_items + 1) 165 | logits_flat = logits_flat[~pad_indices] 166 | 167 | targets_distribution = create_similarity_distirbution(similarity_indices, similarity_values, args.temperature, pos) # (batch_size x max_len, num_items + 1) 168 | 169 | loss = lambd * cross_entropy_criterion(logits_flat / args.temperature, targets_distribution) + (1 - lambd) * loss 170 | 171 | if args.l2_emb > 0: 172 | for param in model.item_emb.parameters(): 173 | loss += args.l2_emb * torch.norm(param) 174 | loss.backward() 175 | adam_optimizer.step() 176 | lambda_scheduler.step() 177 | global_step += 1 178 | print(f"loss in epoch {epoch}: {loss.item()}") 179 | 180 | if epoch % 20 == 0 or epoch == args.num_epochs: 181 | model.eval() 182 | t1 = time.time() - t0 183 | T += t1 184 | print('Evaluating', end='') 185 | test_eval_results = [evaluate_test(model, dataset, args) for _ in range(5)] 186 | val_eval_results = [evaluate_valid(model, dataset, args) for _ in range(5)] 187 | 188 | test_id_hr = test_eval_results[-1][-2] 189 | test_id_ndcg = test_eval_results[-1][-1] 190 | valid_id_hr = val_eval_results[-1][-2] 191 | valid_id_ndcg = val_eval_results[-1][-1] 192 | 193 | test_hr_list = [e[0][1] for e in test_eval_results] 194 | test_ndcg_list = [e[0][0] for e in test_eval_results] 195 | val_hr_list = [e[0][1] for e in val_eval_results] 196 | val_ndcg_list = [e[0][0] for e in val_eval_results] 197 | 198 | # we take the average of different test/val runs 199 | test_hr = np.array(test_hr_list).mean() 200 | test_ndcg = np.array(test_ndcg_list).mean() 201 | val_hr = np.array(val_hr_list).mean() 202 | val_ndcg = np.array(val_ndcg_list).mean() 203 | 204 | (train_ndcg, train_hr), train_id_hr, train_id_ndcg = evaluate_train(model, dataset, args) 205 | s = '\nepoch:%d, time: %f(s), train (NDCG@10: %.4f, HR@10: %.4f), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'\ 206 | % (epoch, T, train_ndcg, train_hr, val_ndcg, val_hr, test_ndcg, test_hr) 207 | print(s) 208 | f.write(s + '\n') 209 | f.flush() 210 | if val_hr > best_val_hr: 211 | torch.save(model.state_dict(), os.path.join(args.train_dir, fname)) 212 | torch.save(test_id_hr, os.path.join(args.train_dir, "test_id_hr.pt")) 213 | torch.save(test_id_ndcg, os.path.join(args.train_dir, "test_id_ndcg.pt")) 214 | torch.save(valid_id_hr, os.path.join(args.train_dir, "valid_id_hr.pt")) 215 | torch.save(valid_id_ndcg, os.path.join(args.train_dir, "valid_id_ndcg.pt")) 216 | torch.save(train_id_hr, os.path.join(args.train_dir, "train_id_hr.pt")) 217 | torch.save(train_id_ndcg, os.path.join(args.train_dir, "train_id_ndcg.pt")) 218 | print(f"New best val HR@10: {val_hr}. 
Saving checkpoint.") 219 | best_val_hr = max(best_val_hr, val_hr) 220 | best_results = { 221 | 'test/best_id_to_HR@10': test_id_hr, 222 | 'test/best_id_to_NDCG@10': test_id_ndcg, 223 | 'val/best_id_to_HR@10': valid_id_hr, 224 | 'val/best_id_to_NDCG@10': valid_id_ndcg, 225 | 'train/best_id_to_HR@10': train_id_hr, 226 | 'train/best_id_to_NDCG@10': train_id_ndcg, 227 | "test/best_NDCG@10": test_ndcg, 228 | "test/best_HR@10": test_hr, 229 | "val/best_NDCG@10": val_ndcg, 230 | "val/best_HR@10": val_hr, 231 | "train/best_NDCG@10": train_ndcg, 232 | "train/best_HR@10": train_hr, 233 | } 234 | t0 = time.time() 235 | model.train() 236 | f.close() 237 | sampler.close() 238 | print("Done") 239 | 240 | 241 | if __name__ == '__main__': 242 | args = parse_args() 243 | main(args) --------------------------------------------------------------------------------