├── .gitignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING-ARCHIVED.md ├── LICENSE.txt ├── README.md ├── RobustnessGymDataset.py ├── RobustnessGymRecsys.py ├── SECURITY.md └── utils └── GlobalVars.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 
19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable 
behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. Representation of a project may be 75 | further defined and clarified by project maintainers. 76 | 77 | ## Enforcement 78 | 79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 80 | reported by contacting the Salesforce Open Source Conduct Committee 81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated 82 | and will result in a response that is deemed necessary and appropriate to the 83 | circumstances. The committee is obligated to maintain confidentiality with 84 | regard to the reporter of an incident. Further details of specific enforcement 85 | policies may be posted separately. 86 | 87 | Project maintainers who do not follow or enforce the Code of Conduct in good 88 | faith may face temporary or permanent repercussions as determined by other 89 | members of the project's leadership and the Salesforce Open Source Conduct 90 | Committee. 91 | 92 | ## Attribution 93 | 94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home], 95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html. 
96 | It includes adaptations and additions from [Go Community Code of Conduct][golang-coc], 97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc]. 98 | 99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us]. 100 | 101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/) 102 | [golang-coc]: https://golang.org/conduct 103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md 104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/ 105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/ -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | we are not accepting contributions or pull requests. 5 | 6 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Salesforce.com, Inc. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RGRecSys 2 | RGRecSys is a robustness evaluation toolkit for recommendation systems. Its robustness tests can currently be applied to general and context-aware recommendation models from the [RecBole library](https://dl.acm.org/doi/abs/10.1145/3459637.3482016). The toolkit lets users evaluate recommender system robustness easily and uniformly. See our [paper](https://dl.acm.org/doi/abs/10.1145/3488560.3502192) for more details. 3 | 4 | ## Requirements 5 | 6 | - python >= 3.7 7 | - torch >= 1.7.0 8 | - numpy >= 1.17.2 9 | - pandas >= 1.0.5 10 | - scikit-learn >= 0.23.2 11 | - recbole == 0.2.1 12 | 13 | ## Usage 14 | 15 | 1. If you want to use a dataset other than *ml-100k*, use the RecBole library to generate the atomic files; otherwise, you can skip this step. The atomic files (.INTER, .USER, and .ITEM) provide a common data representation for different recommendation algorithms. See [this](https://dl.acm.org/doi/abs/10.1145/3459637.3482016) for more information. 16 | 2. 
Create a folder named "saved" in the RGRecSys-master folder. 17 | 3. Specify the model, dataset, and desired robustness test in the main function of RobustnessGymRecSys.py, following the example below: 18 | 19 | ```python 20 | if __name__ == '__main__': 21 | for model in ["BPR"]: #Specify model here 22 | dataset = "ml-100k" #Specify dataset here 23 | base_config_dict = { #Specify selective data loading here. Keys are the suffixes of the loaded atomic files; values are the lists of field names to load 24 | 'load_col': { 25 | 'inter': ['user_id', 'item_id', 'rating', 'timestamp'], 26 | 'user': ['user_id', 'age', 'gender', 'occupation'], 27 | 'item': ['item_id', 'release_year', 'class'] 28 | } 29 | } 30 | robustness_dict = { #Specify the robustness test here. This example shows slicing based on a user feature 31 | "slice": { 32 | "by_feature": { 33 | "occupation": {"equal": "student"} 34 | } 35 | } 36 | } 37 | results = train_and_test( 38 | model=model, 39 | dataset=dataset, 40 | robustness_tests=robustness_dict, 41 | base_config_dict=base_config_dict, 42 | save_model=False 43 | ) 44 | ``` 45 | 46 | Below are more examples of robustness test formatting: 47 | 48 | ###### Slice Test Data by Feature 49 | 50 | ```python 51 | #Ex: A slice of users whose occupation is student 52 | #Format: "user feature": {"equal", "min", or "max": "value"} 53 | "slice": { 54 | "by_feature": { 55 | "occupation": {"equal": "student"} 56 | } 57 | } 58 | ``` 59 | ###### Slice Test Data by Interaction 60 | 61 | ```python 62 | #Ex: A slice of users whose number of interactions is more than 50 63 | #Format: "user": {"equal", "min", or "max": # interactions} 64 | "slice": { 65 | "by_inter": { 66 | "user": {"min": 50} 67 | } 68 | } 69 | ``` 70 | 71 | ###### Sparsify the Training Data 72 | 73 | ```python 74 | #Ex: randomly drop 25% of interactions for users whose number of interactions is more than 10 75 | #Format: "min_user_inter": min number of interactions for each user, "fraction_removed": fraction of 
interactions to remove 76 | "sparsify": { 77 | "min_user_inter": 10, 78 | "fraction_removed": .25 79 | } 80 | ``` 81 | 82 | ###### Transform the Test Data - Structured 83 | 84 | ```python 85 | #Ex: users' ages will be replaced with a value between 0.8 and 1.2 times their original age (a user aged 10 will get an age value randomly selected from 8-12) 86 | #Format: "user or item feature": fraction of the current value that may be added to or subtracted from the original value 87 | "transform_features": { 88 | "structured": { 89 | "age": 0.2, 90 | } 91 | } 92 | ``` 93 | 94 | 95 | ###### Transform the Test Data - Random 96 | 97 | ```python 98 | #Ex: change 40% of user gender values to any other gender value 99 | #Format: "user or item feature": fraction to change 100 | "transform_features": { 101 | "random": { 102 | "gender": .40, 103 | } 104 | } 105 | ``` 106 | 107 | ###### Transform the Training Interactions - Random Attack 108 | 109 | ```python 110 | #Ex: 10% of user interactions are transformed to other values 111 | #Format: "fraction_transformed": fraction to transform 112 | "transform_interactions": { 113 | "fraction_transformed": 0.1 114 | } 115 | ``` 116 | 117 | ###### Distribution Shift in the Test Set 118 | 119 | ```python 120 | #Ex: manipulate the test set to contain 50% male and 50% female users 121 | #Format: "user feature": {proportions of each feature value} 122 | "distribution_shift": { 123 | "gender": { 124 | "M": .5, 125 | "F": .5 126 | } 127 | } 128 | ``` 129 | 130 | 131 | 132 | ## Cite 133 | 134 | If you use RGRecSys in your research or development, please cite the following paper: 135 | ``` 136 | @inproceedings{10.1145/3488560.3502192, 137 | author = {Ovaisi, Zohreh and Heinecke, Shelby and Li, Jia and Zhang, Yongfeng and Zheleva, Elena and Xiong, Caiming}, 138 | title = {RGRecSys: A Toolkit for Robustness Evaluation of Recommender Systems}, 139 | url = {https://dl.acm.org/doi/abs/10.1145/3488560.3502192}, 140 | doi = 
{10.1145/3488560.3502192}, 141 | series = {WSDM '22} 142 | } 143 | ``` 144 | -------------------------------------------------------------------------------- /RobustnessGymDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* 3 | * Copyright (c) 2021, salesforce.com, inc. 4 | * All rights reserved. 5 | * SPDX-License-Identifier: BSD-3-Clause 6 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | */ 8 | """ 9 | 10 | from utils.GlobalVars import * 11 | import copy 12 | import pandas as pd 13 | from sklearn.utils import shuffle as sk_shuffle 14 | from recbole.utils.utils import set_color, ModelType, init_seed 15 | from recbole.data.dataset import Dataset 16 | from recbole.config import Config, EvalSetting 17 | from recbole.utils.enum_type import FeatureType 18 | from collections.abc import Iterable 19 | from collections import Counter 20 | import random 21 | import numpy as np 22 | import torch 23 | 24 | 25 | class RobustnessGymDataset(Dataset): 26 | """ 27 | A RobustnessGymDataset is a modified Dataset. 28 | """ 29 | 30 | def __init__(self, config): 31 | """ 32 | 33 | Args: 34 | config (Config): 35 | """ 36 | super().__init__(config) 37 | 38 | def _data_filtering(self): 39 | """ 40 | Filters data by removing nans, removing duplications, 41 | updating interaction if nans/duplications removed, 42 | and resetting index. 43 | """ 44 | self._filter_nan_user_or_item() 45 | self._remove_duplication() 46 | self._filter_inter_by_user_or_item() 47 | self._reset_index() 48 | 49 | def copy(self, new_inter_feat): 50 | """ 51 | Overloaded copy() in RecBole. This deep copies RobustnessGymDataset and sets inter_feat. 
52 | Args: 53 | new_inter_feat (RobustnessGymDataset): 54 | 55 | Returns: 56 | nxt (RobustnessGymDataset): 57 | """ 58 | nxt = copy.deepcopy(self) 59 | nxt.inter_feat = new_inter_feat 60 | return nxt 61 | 62 | def split_by_ratio(self, ratios, group_by=None): 63 | """ 64 | Overloaded split_by_ratio in RecBole. 65 | Main difference - we split RobustnessGymDataset instance (instead of 66 | Dataloader instance) into train, valid, and test. 67 | Args: 68 | ratios (list): 69 | group_by (): 70 | 71 | Returns: 72 | 73 | """ 74 | self.logger.debug(f'split by ratios [{ratios}], group_by=[{group_by}]') 75 | tot_ratio = sum(ratios) 76 | ratios = [_ / tot_ratio for _ in ratios] 77 | 78 | if group_by is None: 79 | tot_cnt = self.__len__() 80 | split_ids = self._calcu_split_ids(tot=tot_cnt, ratios=ratios) 81 | next_index = [range(start, end) for start, end in zip([0] + split_ids, split_ids + [tot_cnt])] 82 | else: 83 | grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].to_numpy()) 84 | next_index = [[] for _ in range(len(ratios))] 85 | for grouped_index in grouped_inter_feat_index: 86 | tot_cnt = len(grouped_index) 87 | split_ids = self._calcu_split_ids(tot=tot_cnt, ratios=ratios) 88 | for index, start, end in zip(next_index, [0] + split_ids, split_ids + [tot_cnt]): 89 | index.extend(grouped_index[start:end]) 90 | 91 | self._drop_unused_col() 92 | next_df = [self.inter_feat.iloc[index] for index in next_index] 93 | next_ds = [self.copy(_) for _ in next_df] 94 | return next_ds 95 | 96 | def leave_one_out(self, group_by, leave_one_num=1): 97 | """ 98 | Overloaded leave_one_out in RecBole. Main difference - we split RobustnessGymDataset instance 99 | (instead of Dataloader instance) into train, valid, and test. 
100 | Args: 101 | group_by: 102 | leave_one_num: 103 | 104 | Returns: 105 | 106 | """ 107 | self.logger.debug(f'leave one out, group_by=[{group_by}], leave_one_num=[{leave_one_num}]') 108 | if group_by is None: 109 | raise ValueError('leave one out strategy require a group field') 110 | 111 | grouped_inter_feat_index = self._grouped_index(self.inter_feat[group_by].numpy()) 112 | next_index = self._split_index_by_leave_one_out(grouped_inter_feat_index, leave_one_num) 113 | 114 | self._drop_unused_col() 115 | next_df = [self.inter_feat.iloc[index] for index in next_index] 116 | next_ds = [self.copy(_) for _ in next_df] 117 | return next_ds 118 | 119 | def _transform_by_field_value_random(self): 120 | """ 121 | Transforms x% of feature/field values by removing the current value and 122 | replacing with random value selected from set of all possible values. 123 | 124 | Returns: 125 | 126 | """ 127 | transform_percents = self.config['transform_val'] 128 | if transform_percents is None: 129 | return [] 130 | 131 | self.logger.debug(set_color('transform_by_field_value', 'blue') + f': val={transform_percents}') 132 | for field in transform_percents: 133 | if field not in self.field2type: 134 | raise ValueError(f'Field [{field}] not defined in dataset.') 135 | for feat_name in self.feat_name_list: 136 | feat = getattr(self, feat_name) 137 | if field in feat: 138 | # gather all possible field values 139 | field_values = [] 140 | for index, row in feat.iterrows(): 141 | if not isinstance(row[field], Iterable) and row[field] != 0 and row[field] not in field_values: 142 | field_values.append(row[field]) 143 | elif isinstance(row[field], Iterable) and len(row[field]) != 0: 144 | for i in row[field]: 145 | if i not in field_values: 146 | field_values.append(i) 147 | random_indices = random.sample(range(1, len(feat) - 1), 148 | round(transform_percents[field] * len(feat) - 1)) 149 | for i in random_indices: 150 | field_value_choices = field_values[:] 151 | if not 
isinstance(feat.iloc[i, feat.columns.get_loc(field)], Iterable): 152 | # remove current value and replace with another chosen at random 153 | field_value_choices.remove(feat.iloc[i, feat.columns.get_loc(field)]) 154 | feat.iloc[i, feat.columns.get_loc(field)] = random.choice(field_value_choices) 155 | elif isinstance(feat.iloc[i, feat.columns.get_loc(field)], Iterable): 156 | for j in feat.iloc[i, feat.columns.get_loc(field)]: 157 | field_value_choices.remove(j) 158 | # remove iterable and replace with ONE randomly chosen value 159 | feat.iloc[i, feat.columns.get_loc(field)] = np.array([[random.choice(field_value_choices)]]) 160 | return field_values 161 | 162 | def _transform_by_field_value_structured(self): 163 | """ 164 | Transforms field/feature in structured manner. 165 | 166 | (1) If feature value is a single value (float, int), then the value is replaced with a value within x% of the 167 | current value. For example, age = 30, x = 10% --> may be replaced with age = 32. 168 | (2) If feature value is an iterable (list, numpy array), then x% of the values are dropped. 
169 | For example, genre = [Horror, Drama, Romance], x = 33% --> may be replaced with genre = [Horror, Romance] 170 | """ 171 | 172 | transform_percents = self.config['DropeFraction_or_variance_transform_val'] 173 | 174 | if transform_percents is None: 175 | return [] 176 | self.logger.debug(set_color('_transform_by_field_value', 'blue') + f': val={transform_percents}') 177 | 178 | for field in transform_percents: 179 | if field not in self.field2type: 180 | raise ValueError(f'Field [{field}] not defined in dataset.') 181 | for feat_name in self.feat_name_list: 182 | feat = getattr(self, feat_name) 183 | if field in feat: 184 | random_indices = random.sample(range(1, len(feat) - 1), 185 | round(transform_percents[field] * len(feat) - 1)) 186 | for i in random_indices: 187 | if not isinstance(feat.iloc[i, feat.columns.get_loc(field)], Iterable): 188 | # replaces current value with random integer within x% of current value 189 | random_value = random.randint( 190 | round((1 - transform_percents[field]) * feat.iloc[i, feat.columns.get_loc(field)]), 191 | round((1 + transform_percents[field]) * feat.iloc[i, feat.columns.get_loc(field)])) 192 | feat.iloc[i, feat.columns.get_loc(field)] = random_value 193 | elif isinstance(feat.iloc[i, feat.columns.get_loc(field)], Iterable) and len( 194 | feat.iloc[i, feat.columns.get_loc(field)]) > 1: 195 | # randomly sample x% from iterable/list and remove them 196 | dropped_values = random.sample(list(feat.iloc[i, feat.columns.get_loc(field)]), 197 | round(transform_percents[field] * 198 | len(feat.iloc[i, feat.columns.get_loc(field)]))) 199 | for item in dropped_values: 200 | feat.iat[i, feat.columns.get_loc(field)] = np.array( 201 | feat.iloc[i, feat.columns.get_loc(field)][ 202 | feat.iloc[i, feat.columns.get_loc(field)] != item]) 203 | 204 | def _transform_by_field_value_delete_feat(self): 205 | """ 206 | Transforms field by "deleting" x% of feature values. 
Since the feature value cannot be truly deleted, 207 | we instead remove x% of feature values and replace with the average value of the feature. 208 | """ 209 | 210 | delete_percent = self.config['DeleteFraction_transform_val'] 211 | if delete_percent is None: 212 | return [] 213 | 214 | self.logger.debug(set_color('_transform_by_field_value', 'blue') + f': val={delete_percent}') 215 | for field in delete_percent: 216 | if field not in self.field2type: 217 | raise ValueError(f'Field [{field}] not defined in dataset.') 218 | value_list = [] 219 | for feat_name in self.feat_name_list: 220 | feat = getattr(self, feat_name) 221 | if field in feat: 222 | # compute average value of feature/field 223 | for i in range(len(feat)): 224 | value_list.append(feat.iloc[i, feat.columns.get_loc(field)]) 225 | avg_value = np.mean(value_list) 226 | 227 | for feat_name in self.feat_name_list: 228 | feat = getattr(self, feat_name) 229 | if field in feat: 230 | random_indices = random.sample(range(1, len(feat) - 1), 231 | round(delete_percent[field] * len(feat) - 1)) 232 | for i in random_indices: 233 | if not isinstance(feat.iloc[i, feat.columns.get_loc(field)], Iterable): 234 | # replace with average value of feature 235 | feat.iloc[i, feat.columns.get_loc(field)] = avg_value 236 | 237 | def _make_data_more_sparse(self): 238 | """ 239 | 240 | Returns: 241 | 242 | """ 243 | val1 = self.config['selected_user_spars_data'] 244 | val2 = self.config['fraction_spars_data'] 245 | user_D = {} 246 | item_D = {} 247 | 248 | for line in range(len(self.inter_feat)): 249 | user_id = self.inter_feat.iloc[line]["user_id"] 250 | item_id = self.inter_feat.iloc[line]["item_id"] 251 | 252 | if user_id not in user_D: 253 | user_D[user_id] = [] 254 | user_D[user_id].append(item_id) 255 | if item_id not in item_D: 256 | item_D[item_id] = [] 257 | item_D[item_id].append(user_id) 258 | 259 | for user_id in user_D: 260 | if len(user_D[user_id]) > val1: 261 | selected_item_id = random.sample(user_D[user_id], 
round(val2 * len(user_D[user_id]))) 262 | for item in selected_item_id: 263 | self.inter_feat.drop(self.inter_feat.loc[self.inter_feat['user_id'] == user_id].loc[ 264 | self.inter_feat['item_id'] == item].index, inplace=True) 265 | 266 | def _transform_interactions_random(self): 267 | """ 268 | 269 | Returns: 270 | 271 | """ 272 | transform_fraction = self.config['transform_inter'] 273 | if transform_fraction is None: 274 | return [] 275 | 276 | random_rating = 0 277 | possible_values = [0.0, 1.0] 278 | random_rows = random.sample(list(self.inter_feat.index), round(transform_fraction * len(self.inter_feat))) 279 | for index in random_rows: 280 | if self.config['MODEL_TYPE'] == ModelType.GENERAL or self.config['MODEL_TYPE'] == ModelType.TRADITIONAL: 281 | transform_col = "rating" 282 | get_random_rating = True 283 | while get_random_rating: 284 | random_rating = round(random.uniform(possible_values[0], possible_values[1]), 2) 285 | if random_rating != self.inter_feat[transform_col].loc[index]: 286 | get_random_rating = False 287 | self.inter_feat[transform_col].loc[index] = random_rating 288 | if self.config['MODEL_TYPE'] == ModelType.CONTEXT: 289 | transform_col = "label" 290 | if self.inter_feat[transform_col].loc[index] == 1.0: 291 | self.inter_feat[transform_col].loc[index] = 0.0 292 | else: 293 | self.inter_feat[transform_col].loc[index] = 1.0 294 | 295 | @staticmethod 296 | def _get_user_or_item_subset(feat_file, field, val_list): 297 | """ 298 | 299 | Args: 300 | user_feat (Dataframe): 301 | feature (str): 302 | val_list (list): 303 | 304 | Returns: 305 | 306 | """ 307 | return {val: list(feat_file[feat_file[field] == val]) for val in val_list} 308 | 309 | def _distributional_slice_old(self): 310 | """ 311 | Older implementation of distribution shift based on removing prescribed 312 | proportions of test subpopulations. 
313 | Returns: 314 | 315 | """ 316 | dist_slice = self.config['distribution_shift'] 317 | print(dist_slice) 318 | if dist_slice is None: 319 | return [] 320 | 321 | for field in dist_slice: 322 | distribution = dist_slice[field] 323 | distribution_keys = list(dist_slice[field].keys()) 324 | print(distribution) 325 | print(distribution_keys) 326 | print(len(self.inter_feat)) 327 | if field not in self.field2type: 328 | raise ValueError(f'Field [{field}] not defined in dataset.') 329 | if self.field2type[field] not in {FeatureType.TOKEN}: 330 | raise ValueError(f'Currently only works for Token types.') 331 | for feat_name in self.feat_name_list: 332 | feat = getattr(self, feat_name) 333 | if field in feat: 334 | user_dict = {} 335 | unique_vals = list(feat[field].unique()) 336 | for tru_val in unique_vals: 337 | user_dict[tru_val] = list(feat[feat[field] == tru_val][self.uid_field]) 338 | for val, proportion in distribution.items(): 339 | if val != 0.0: 340 | tru_val = self.field2token_id[field][val] 341 | for index, row in self.inter_feat.iterrows(): 342 | if row[self.uid_field] in user_dict[tru_val]: 343 | rand_val = random.uniform(0, 1) 344 | if rand_val <= proportion: 345 | self.inter_feat.drop(index, inplace=True) 346 | 347 | def create_distribution(self): 348 | """ 349 | 350 | Returns: 351 | 352 | """ 353 | dist_shift = self.config['distribution_shift'] 354 | if dist_shift is None: 355 | return [] 356 | 357 | for field in dist_shift: 358 | distribution_dict = dist_shift[field] 359 | # supports distribution dict of size 2 only 360 | assert (len(distribution_dict) == 2) 361 | if field not in self.field2type: 362 | raise ValueError(f'Field [{field}] not defined in dataset.') 363 | if sum(list(distribution_dict.values())) != 1: 364 | raise ValueError(f'Distribution needs to add up to 1.') 365 | if self.field2type[field] not in {FeatureType.TOKEN}: 366 | raise ValueError(f'Currently only works for Token types.') 367 | for feat_name in self.feat_name_list: 368 | feat 
= getattr(self, feat_name) 369 | if field in feat: 370 | user_val_dict = {} 371 | user_val_counts = {} 372 | user_val_original_proportions = {} 373 | unique_vals = list(feat[field].unique()) 374 | for val in unique_vals: 375 | user_val_dict[val] = list(feat[feat[field] == val][self.uid_field]) 376 | user_val_counts[val] = len( 377 | [i for i in self.inter_feat[self.uid_field] if i in user_val_dict[val]]) 378 | for val, proportion in distribution_dict.items(): 379 | if val != 0.0: 380 | token_val = self.field2token_id[field][val] 381 | user_val_original_proportions[val] = user_val_counts[token_val] / len(self.inter_feat) 382 | no_change_val = 0 383 | no_change_quantity = 0 384 | for val, proportion in distribution_dict.items(): 385 | token_val = self.field2token_id[field][val] 386 | if proportion >= user_val_original_proportions[val]: 387 | no_change_val = val 388 | no_change_new_proportion = proportion 389 | no_change_quantity = user_val_counts[token_val] 390 | num_new_test = int(no_change_quantity / no_change_new_proportion) 391 | num_other_class = num_new_test - no_change_quantity 392 | for val, proportion in distribution_dict.items(): 393 | token_val = self.field2token_id[field][val] 394 | if val != no_change_val: 395 | original_val = user_val_counts[token_val] 396 | drop_indices = np.random.choice( 397 | self.inter_feat.index[self.inter_feat[self.uid_field].isin(user_val_dict[token_val])], 398 | original_val - num_other_class, replace=False) 399 | self.inter_feat = self.inter_feat.drop(drop_indices) 400 | new_quantity = len( 401 | [i for i in self.inter_feat[self.uid_field] if i in user_val_dict[token_val]]) 402 | 403 | @staticmethod 404 | def create_distribution_slice(train, test): 405 | print("Preparing distributional test slice.") 406 | train.get_training_distribution_statistics() 407 | slice_test = copy.deepcopy(test) 408 | slice_test.create_distribution() 409 | # slice_test.get_training_distribution_statistics() 410 | # 
slice_test._filter_inter_by_user_or_item() 411 | slice_test._reset_index() 412 | slice_test._user_item_feat_preparation() 413 | return slice_test 414 | 415 | def get_training_distribution_statistics(self): 416 | """ 417 | 418 | Returns: 419 | 420 | """ 421 | dist_slice = self.config['distribution_shift'] 422 | if dist_slice is None: 423 | print("No Training Stats Computed") 424 | return [] 425 | 426 | for field in dist_slice: 427 | user_dict = {} 428 | for feat_name in self.feat_name_list: 429 | feat = getattr(self, feat_name) 430 | if field in feat: 431 | unique_vals = list(feat[field].unique()) 432 | for val in unique_vals: 433 | user_dict[val] = list(feat[feat[field] == val][self.uid_field]) 434 | dist = {} 435 | for val in user_dict: 436 | if val != 0.0: 437 | dist[val] = len(self.inter_feat[self.inter_feat[self.uid_field].isin(user_dict[val])]) 438 | print("Training Distribution:") 439 | for val in user_dict: 440 | if val != 0.0: 441 | print("Val: ", self.field2id_token[field][int(val)], "Percent: ", 442 | dist[val] / sum(list(dist.values()))) 443 | 444 | def get_attack_statistics(self, train): 445 | # TODO: add more statistics 446 | """ 447 | 448 | Args: 449 | train: 450 | 451 | Returns: 452 | 453 | """ 454 | print("Interaction Transformation Robustness Test Summary") 455 | 456 | def get_distribution_shift_statistics(self, train, test): 457 | print("Distribution Shift Robustness Test Summary") 458 | 459 | def get_transformation_statistics(self, test): 460 | # TODO: improve printed information 461 | print("Transformation of Features Robustness Test Summary") 462 | print("Original Test Size: ", len(test.inter_feat)) 463 | print("Original Test Users: ", len(test.inter_feat[self.uid_field].unique())) 464 | print("Original Test Features Distribution") 465 | 466 | print("Transformed Test Size: ", len(self.inter_feat)) 467 | print("Transformed Test Users: ", len(self.inter_feat[self.uid_field].unique())) 468 | print("Transformed Test Features Distribution") 469 | 
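Editor's aside: `get_sparsity_statistics` below reports interaction and user counts before and after the sparsify robustness test, which (per the README and `_make_data_more_sparse` above) randomly drops a fraction of interactions for users with more than a minimum number of interactions. A minimal, self-contained sketch of that behavior, assuming a plain list of `(user_id, item_id)` pairs rather than RecBole's `inter_feat` DataFrame; the `sparsify` name and `seed` parameter are illustrative, not part of this repository:

```python
# Illustrative sketch of the "sparsify" robustness test (not part of
# RobustnessGymDataset): drop `fraction_removed` of each sufficiently
# active user's interactions, leave other users untouched.
import random
from collections import defaultdict

def sparsify(interactions, min_user_inter, fraction_removed, seed=0):
    """interactions: list of (user_id, item_id) pairs."""
    rng = random.Random(seed)
    by_user = defaultdict(list)
    for user, item in interactions:
        by_user[user].append(item)
    kept = []
    for user, items in by_user.items():
        if len(items) > min_user_inter:
            # pick indices of this user's interactions to drop
            n_drop = round(fraction_removed * len(items))
            drop = set(rng.sample(range(len(items)), n_drop))
            kept.extend((user, it) for i, it in enumerate(items) if i not in drop)
        else:
            kept.extend((user, it) for it in items)
    return kept
```

With `min_user_inter=10` and `fraction_removed=0.25`, a user with 20 interactions keeps 15 of them, while a user with 10 or fewer keeps all.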
470 | def get_sparsity_statistics(self, train): 471 | """ 472 | 473 | Args: 474 | train: original (unsparsified) training dataset. 475 | 476 | Returns: 477 | 478 | """ 479 | print("Sparsity Robustness Test Summary") 480 | print("Original Train Size: ", len(train.inter_feat)) 481 | print("Original Train Users: ", len(train.inter_feat[self.uid_field].unique())) 482 | print("Sparsified Train Size: ", len(self.inter_feat)) 483 | print("Sparsified Train Users: ", len(self.inter_feat[self.uid_field].unique())) 484 | 485 | @staticmethod 486 | def create_transformed_test(test): 487 | """ 488 | 489 | Args: 490 | test: original test dataset to be copied and transformed. 491 | 492 | Returns: 493 | 494 | """ 495 | print("Preparing test set transformation.") 496 | transformed_test = copy.deepcopy(test) 497 | transformed_test.read_transform_features() 498 | transformed_test._transform_by_field_value_random() 499 | transformed_test._transform_by_field_value_structured() 500 | transformed_test._transform_by_field_value_delete_feat() 501 | transformed_test.get_transformation_statistics(test) 502 | return transformed_test 503 | 504 | @staticmethod 505 | def create_transformed_train(train): 506 | """ 507 | Copy `train` and apply the configured interaction transformations. 508 | Returns: 509 | transformed_train (RobustnessGymDataset) 510 | """ 511 | print("Preparing training set transformation.") 512 | transformed_train = copy.deepcopy(train) 513 | transformed_train.read_transform_interactions() 514 | transformed_train._transform_interactions_random() 515 | transformed_train.get_attack_statistics(train) 516 | return transformed_train 517 | 518 | def read_transform_interactions(self): 519 | transform_config = self.config.final_config_dict["transform_interactions"] 520 | 521 | if transform_config is None: 522 | print("No transformation configs.") 523 | return None 524 | 525 | if "fraction_transformed" in transform_config: 526 | self.config.final_config_dict["transform_inter"] = transform_config["fraction_transformed"] 527 | else: 528 | print("No transformation percent specified.") 529 | return None 530 | 531 | def read_sparsify(self): 532 | """ 533 | Read the `sparsify` config into min_user_inter / fraction_removed settings. 534 | Returns: 535 | 536 | """ 537
| sparsify_config = self.config.final_config_dict["sparsify"] 538 | 539 | if sparsify_config is None: 540 | print("No sparsity configs.") 541 | return None 542 | 543 | if "min_user_inter" in sparsify_config: 544 | min_val = sparsify_config["min_user_inter"] 545 | self.config.final_config_dict['selected_user_spars_data'] = min_val 546 | else: 547 | self.config.final_config_dict['selected_user_spars_data'] = 0 548 | 549 | if "fraction_removed" in sparsify_config: 550 | fraction = sparsify_config["fraction_removed"] 551 | self.config.final_config_dict["fraction_spars_data"] = fraction 552 | else: 553 | print("No sparsity fraction specified.") 554 | return None 555 | 556 | @staticmethod 557 | def create_sparse_train(train): 558 | """ 559 | 560 | Args: 561 | train: 562 | 563 | Returns: 564 | 565 | """ 566 | print("Preparing sparsified training data set.") 567 | sparse_train = copy.deepcopy(train) 568 | sparse_train.read_sparsify() 569 | sparse_train._make_data_more_sparse() 570 | sparse_train.get_sparsity_statistics(train) 571 | return sparse_train 572 | 573 | def _filter_by_inter_num(self, train): 574 | """ 575 | Overloaded RecBole. This version calls adjusted version of _get_illegal_ids below. 
576 | Args: 577 | train: training dataset whose interaction counts determine the illegal ids. 578 | 579 | Returns: 580 | 581 | """ 582 | ban_users = self._get_illegal_ids_by_inter_num(dataset=train, field=self.uid_field, feat=self.user_feat, 583 | max_num=self.config['max_user_inter_num'], 584 | min_num=self.config['min_user_inter_num']) 585 | ban_items = self._get_illegal_ids_by_inter_num(dataset=train, field=self.iid_field, feat=self.item_feat, 586 | max_num=self.config['max_item_inter_num'], 587 | min_num=self.config['min_item_inter_num']) 588 | 589 | if len(ban_users) == 0 and len(ban_items) == 0: 590 | return 591 | 592 | if self.user_feat is not None: 593 | dropped_user = self.user_feat[self.uid_field].isin(ban_users) 594 | self.user_feat.drop(self.user_feat.index[dropped_user], inplace=True) 595 | 596 | if self.item_feat is not None: 597 | dropped_item = self.item_feat[self.iid_field].isin(ban_items) 598 | self.item_feat.drop(self.item_feat.index[dropped_item], inplace=True) 599 | 600 | dropped_inter = pd.Series(False, index=self.inter_feat.index) 601 | if self.uid_field: 602 | dropped_inter |= self.inter_feat[self.uid_field].isin(ban_users) 603 | if self.iid_field: 604 | dropped_inter |= self.inter_feat[self.iid_field].isin(ban_items) 605 | self.logger.debug('[{}] dropped interactions'.format(dropped_inter.sum())) 606 | self.inter_feat.drop(self.inter_feat.index[dropped_inter], inplace=True) 607 | 608 | def _get_illegal_ids_by_inter_num(self, dataset, field, feat, max_num=None, min_num=None): 609 | """ 610 | Overloaded from RecBole. This version uses *train* interactions for slicing.
611 | Args: 612 | field: 613 | feat: 614 | max_num: 615 | min_num: 616 | 617 | Returns: 618 | 619 | """ 620 | self.logger.debug('\n get_illegal_ids_by_inter_num:\n\t field=[{}], max_num=[{}], min_num=[{}]'.format( 621 | field, max_num, min_num 622 | )) 623 | 624 | if field is None: 625 | return set() 626 | if max_num is None and min_num is None: 627 | return set() 628 | 629 | max_num = max_num or np.inf 630 | min_num = min_num or -1 631 | 632 | ids = dataset[field].values 633 | inter_num = Counter(ids) 634 | ids = {id_ for id_ in inter_num if inter_num[id_] < min_num or inter_num[id_] > max_num} 635 | 636 | if feat is not None: 637 | for id_ in feat[field].values: 638 | if inter_num[id_] < min_num: 639 | ids.add(id_) 640 | self.logger.debug('[{}] illegal_ids_by_inter_num, field=[{}]'.format(len(ids), field)) 641 | return ids 642 | 643 | def _drop_by_value(self, val, cmp): 644 | """ 645 | Overloaded _drop_by_value function from RecBole Dataset base class. 646 | Here we enable filtering for any field type (not just floats). We also 647 | enable dropping of categorical features. This function is called by 648 | _filter_by_field_value() in RecBole. 649 | 650 | Args: 651 | val (dict): 652 | cmp (Callable): 653 | 654 | Returns: 655 | filter_field (list): field names used in comparison. 
656 | 657 | """ 658 | 659 | if val is None: 660 | return [] 661 | 662 | self.logger.debug(set_color('drop_by_value', 'blue') + f': val={val}') 663 | filter_field = [] 664 | for field in val: 665 | if field not in self.field2type: 666 | raise ValueError(f'Field [{field}] not defined in dataset.') 667 | for feat_name in self.feat_name_list: 668 | feat = getattr(self, feat_name) 669 | if field in feat: 670 | if self.field2type[field] == FeatureType.TOKEN_SEQ: 671 | raise NotImplementedError 672 | if self.field2type[field] == FeatureType.TOKEN: 673 | # tokens are mapped to new values by __init__() 674 | if isinstance(val[field], str): 675 | feat.drop(feat.index[cmp(feat[field].values, self.field2token_id[field][val[field]])], 676 | inplace=True) 677 | else: 678 | def convert_to_orig_val(x): 679 | if int(x) == 0: 680 | return 0.0 681 | else: 682 | try: 683 | return float(self.field2id_token[field][int(x)]) 684 | except (ValueError, IndexError): 685 | return 0.0 686 | 687 | original_tokens = np.array([convert_to_orig_val(i) for i in feat[field].values]) 688 | feat.drop(feat.index[cmp(original_tokens, float(val[field]))], inplace=True) 689 | if self.field2type[field] in {FeatureType.FLOAT, FeatureType.FLOAT_SEQ}: 690 | feat.drop(feat.index[cmp(feat[field].values, val[field])], inplace=True) 691 | filter_field.append(field) 692 | return filter_field 693 | 694 | def get_slice_statistics(self, test): 695 | """ 696 | 697 | Args: 698 | test: original (unsliced) test dataset, used for 699 | comparison statistics. 700 | 701 | Returns: 702 | 703 | """ 704 | print("Slice Robustness Test Summary") 705 | print("Original Test Size: ", len(test.inter_feat)) 706 | print("Original Test Users: ", len(test.inter_feat[self.uid_field].unique())) 707 | print("Subpopulation Size: ", len(self.inter_feat)) 708 | print("Subpopulation Users: ", len(self.inter_feat[self.uid_field].unique())) 709 | 710 | def create_slice(self, test, train): 711 | slice_config = self.config.final_config_dict["slice"] 712 | slice_test = copy.deepcopy(test) 713 | print("Preparing subpopulation
of Test set.") 714 | if "by_feature" in slice_config: 715 | slice_test = self.create_slice_by_feature(slice_test) 716 | if "by_inter" in slice_config: 717 | slice_test = self.create_slice_by_inter(slice_test, train) 718 | slice_test._reset_index() 719 | slice_test._user_item_feat_preparation() 720 | slice_test.get_slice_statistics(test) 721 | return slice_test 722 | 723 | def create_slice_by_inter(self, slice_test, train): 724 | print("Preparing test set slice based on training set interactions.") 725 | slice_test.read_slice_by_inter() 726 | slice_test._filter_by_inter_num(train) 727 | return slice_test 728 | 729 | def read_slice_by_inter(self): 730 | feature_config = self.config.final_config_dict["slice"]["by_inter"] 731 | 732 | if feature_config is None: 733 | print("No interaction subset specified.") 734 | return None 735 | 736 | if "user" in feature_config: 737 | user_inter = feature_config["user"] 738 | assert (type(user_inter) == dict) 739 | if "min" in user_inter: 740 | min_val = user_inter["min"] 741 | self.config.final_config_dict["min_user_inter_num"] = min_val 742 | if "max" in user_inter: 743 | max_val = user_inter["max"] 744 | self.config.final_config_dict["max_user_inter_num"] = max_val 745 | if "item" in feature_config: 746 | item_inter = feature_config["item"] 747 | assert (type(item_inter) == dict) 748 | if "min" in item_inter: 749 | min_val = item_inter["min"] 750 | self.config.final_config_dict["min_item_inter_num"] = min_val 751 | if "max" in item_inter: 752 | max_val = item_inter["max"] 753 | self.config.final_config_dict["max_item_inter_num"] = max_val 754 | 755 | def create_slice_by_feature(self, slice_test): 756 | print("Preparing test set slice based on feature values.") 757 | slice_test.read_slice_by_feature() 758 | slice_test._filter_by_field_value() 759 | slice_test._filter_inter_by_user_or_item() 760 | return slice_test 761 | 762 | def read_slice_by_feature(self): 763 | feature_config = 
self.config.final_config_dict["slice"]["by_feature"] 764 | 765 | if feature_config is None: 766 | print("No feature values specified.") 767 | return None 768 | 769 | for field in feature_config: 770 | for feat_name in self.feat_name_list: 771 | feat = getattr(self, feat_name) 772 | if field in feat: 773 | if field not in self.field2type: 774 | raise ValueError(f'Field [{field}] not defined in dataset.') 775 | slice_specs = feature_config[field] 776 | if type(slice_specs) == dict: 777 | if "min" in slice_specs: 778 | min_dict = {field: slice_specs["min"]} 779 | if self.config.final_config_dict["lowest_val"] is None: 780 | self.config.final_config_dict["lowest_val"] = min_dict 781 | else: 782 | self.config.final_config_dict["lowest_val"].update(min_dict) 783 | if "max" in slice_specs: 784 | max_dict = {field: slice_specs["max"]} 785 | if self.config.final_config_dict["highest_val"] is None: 786 | self.config.final_config_dict["highest_val"] = max_dict 787 | else: 788 | self.config.final_config_dict["highest_val"].update(max_dict) 789 | if "equal" in slice_specs: 790 | equal_dict = {field: slice_specs["equal"]} 791 | if self.config.final_config_dict["equal_val"] is None: 792 | self.config.final_config_dict["equal_val"] = equal_dict 793 | else: 794 | self.config.final_config_dict["equal_val"].update(equal_dict) 795 | else: 796 | print("Incorrect config format.") 797 | return None 798 | 799 | def read_transform_features(self): 800 | feature_config = self.config.final_config_dict["transform_features"] 801 | 802 | if feature_config is None: 803 | print("No feature transformation specified.") 804 | return None 805 | 806 | if "structured" in feature_config: 807 | self.config.final_config_dict['DropeFraction_or_variance_transform_val'] = {} 808 | for field in feature_config["structured"]: 809 | percent = feature_config["structured"][field] 810 | self.config.final_config_dict['DropeFraction_or_variance_transform_val'].update({field: percent}) 811 | elif "random" in 
feature_config: 812 | self.config.final_config_dict['transform_val'] = {} 813 | for field in feature_config["random"]: 814 | percent = feature_config["random"][field] 815 | self.config.final_config_dict['transform_val'].update({field: percent}) 816 | else: 817 | print("Transformation of features incorrectly specified.") 818 | return None 819 | 820 | def create_robustness_datasets(self, train, valid, test): 821 | """ 822 | Create the modified datasets needed for robustness tests according to robustness_dict configurations. 823 | Args: 824 | train (RobustnessGymDataset): 825 | valid (RobustnessGymDataset): 826 | test (RobustnessGymDataset): 827 | 828 | Returns: 829 | 830 | """ 831 | final_config = self.config.final_config_dict 832 | robustness_testing_datasets = {} 833 | 834 | if "slice" in final_config: 835 | robustness_testing_datasets["slice"] = self.create_slice(test, train) 836 | 837 | if "sparsify" in final_config: 838 | robustness_testing_datasets["sparsity"] = self.create_sparse_train(train) 839 | 840 | if "transform_features" in final_config: 841 | robustness_testing_datasets['transformation_test'] = self.create_transformed_test(test) 842 | 843 | if "transform_interactions" in final_config: 844 | robustness_testing_datasets['transformation_train'] = self.create_transformed_train(train) 845 | 846 | if "distribution_shift" in final_config: 847 | robustness_testing_datasets['distributional_slice'] = self.create_distribution_slice(train, test) 848 | 849 | return robustness_testing_datasets 850 | 851 | def build(self, eval_setting): 852 | """ 853 | Overloads RecBole build. Our version builds train, valid, test 854 | and modified versions of train, valid, test as needed according to the 855 | robustness tests requested in the robustness_dict. 
856 | Args: 857 | eval_setting (EvalSetting): RecBole evaluation settings. 858 | 859 | Returns: 860 | original_datasets (list): list containing original train, valid, test datasets 861 | robustness_testing_datasets (dict): {robustness test name: modified dataset} key value pairs 862 | 863 | """ 864 | if self.benchmark_filename_list is not None: 865 | raise NotImplementedError() 866 | 867 | ordering_args = eval_setting.ordering_args 868 | if ordering_args['strategy'] == 'shuffle': 869 | self.inter_feat = sk_shuffle(self.inter_feat) 870 | self.inter_feat = self.inter_feat.reset_index(drop=True) 871 | elif ordering_args['strategy'] == 'by': 872 | raise NotImplementedError() 873 | 874 | group_field = eval_setting.group_field 875 | split_args = eval_setting.split_args 876 | 877 | if split_args['strategy'] == 'by_ratio': 878 | original_datasets = self.split_by_ratio(split_args['ratios'], group_by=group_field) 879 | elif split_args['strategy'] == 'by_value': 880 | raise NotImplementedError() 881 | elif split_args['strategy'] == 'loo': 882 | original_datasets = self.leave_one_out(group_by=group_field, leave_one_num=split_args['leave_one_num']) 883 | else: 884 | original_datasets = self 885 | 886 | train, valid, test = original_datasets 887 | robustness_testing_datasets = self.create_robustness_datasets(train, valid, test) 888 | 889 | for data in list(robustness_testing_datasets.values()) + original_datasets: 890 | if data is not None: 891 | data.inter_feat = data.inter_feat.reset_index(drop=True) 892 | data._change_feat_format() 893 | if ordering_args['strategy'] == 'shuffle': 894 | torch.manual_seed(self.config['seed']) 895 | data.shuffle() 896 | elif ordering_args['strategy'] == 'by': 897 | data.sort(by=ordering_args['field'], ascending=ordering_args['ascending']) 898 | 899 | return original_datasets, robustness_testing_datasets 900 | 901 | 902 | if __name__ == '__main__': 903 | config = Config(model="DCN", dataset="ml-100k", 904 | config_dict={'distribution_shift': {'gender': {"M": .9, "F":
.1}}}) 905 | init_seed(config['seed'], config['reproducibility']) 906 | data = RobustnessGymDataset(config) 907 | datasets, robust_dict = data.build(EvalSetting(config)) 908 | print(robust_dict.keys()) 909 | -------------------------------------------------------------------------------- /RobustnessGymRecsys.py: -------------------------------------------------------------------------------- 1 | """ 2 | /* 3 | * Copyright (c) 2021, salesforce.com, inc. 4 | * All rights reserved. 5 | * SPDX-License-Identifier: BSD-3-Clause 6 | * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 7 | */ 8 | """ 9 | from utils.GlobalVars import * 10 | from recbole.config import Config, EvalSetting 11 | from recbole.sampler import Sampler, RepeatableSampler, KGSampler 12 | from recbole.utils import ModelType, init_logger, get_model, get_trainer, init_seed, InputType 13 | from recbole.utils.utils import set_color 14 | from recbole.data.utils import get_data_loader 15 | from recbole.data import save_split_dataloaders 16 | from RobustnessGymDataset import RobustnessGymDataset 17 | from logging import getLogger, shutdown 18 | import importlib 19 | import pprint as pprint 20 | import pickle 21 | 22 | 23 | def create_dataset(config): 24 | """ 25 | Initializes RobustnessGymDataset for each recommendation system type in RecBole. 26 | Args: 27 | config (Config): Config file indicating MODEL_TYPE and model. 28 | 29 | Returns: 30 | RobustnessGymDataset instance. 
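    Example (hypothetical model name; a sketch assuming RecBole defaults)::

        config = Config(model='BPR', dataset='ml-100k')
        dataset = create_dataset(config)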
31 | """ 32 | dataset_module = importlib.import_module('recbole.data.dataset') 33 | if hasattr(dataset_module, config['model'] + 'Dataset'): 34 | return getattr(dataset_module, config['model'] + 'Dataset')(config) 35 | else: 36 | model_type = config['MODEL_TYPE'] 37 | if model_type == ModelType.SEQUENTIAL: 38 | from recbole.data.dataset import SequentialDataset 39 | SequentialDataset.__bases__ = (RobustnessGymDataset,) 40 | return SequentialDataset(config) 41 | elif model_type == ModelType.KNOWLEDGE: 42 | from recbole.data.dataset import KnowledgeBasedDataset 43 | KnowledgeBasedDataset.__bases__ = (RobustnessGymDataset,) 44 | return KnowledgeBasedDataset(config) 45 | elif model_type == ModelType.SOCIAL: 46 | from recbole.data.dataset import SocialDataset 47 | SocialDataset.__bases__ = (RobustnessGymDataset,) 48 | return SocialDataset(config) 49 | elif model_type == ModelType.DECISIONTREE: 50 | from recbole.data.dataset import DecisionTreeDataset 51 | DecisionTreeDataset.__bases__ = (RobustnessGymDataset,) 52 | return DecisionTreeDataset(config) 53 | else: 54 | return RobustnessGymDataset(config) 55 | 56 | 57 | def get_transformed_train(config, train_kwargs, train_dataloader, robustness_testing_datasets): 58 | """ 59 | Converts training data set created by transformations into dataloader object. Uses same config 60 | settings as original training data. 
61 | 62 | Args: 63 | train_kwargs (dict): Training dataset config 64 | train_dataloader (Dataloader): Training dataloader 65 | config (Config): General config 66 | robustness_testing_datasets (dict): Modified datasets resulting from robustness tests 67 | 68 | Returns: 69 | transformed_train (Dataloader) 70 | """ 71 | transformed_train = None 72 | if "transformation_train" in robustness_testing_datasets: 73 | transformation_kwargs = { 74 | 'config': config, 75 | 'dataset': robustness_testing_datasets['transformation_train'], 76 | 'batch_size': config['train_batch_size'], 77 | 'dl_format': config['MODEL_INPUT_TYPE'], 78 | 'shuffle': True, 79 | } 80 | try: 81 | transformation_kwargs['sampler'] = train_kwargs['sampler'] 82 | transformation_kwargs['neg_sample_args'] = train_kwargs['neg_sample_args'] 83 | transformed_train = train_dataloader(**transformation_kwargs) 84 | except KeyError:  # no sampler/neg_sample_args in train_kwargs; build without them 85 | transformed_train = train_dataloader(**transformation_kwargs) 86 | 87 | return transformed_train 88 | 89 | 90 | def get_sparsity_train(config, train_kwargs, train_dataloader, robustness_testing_datasets): 91 | """ 92 | Converts the sparsified training data set into a dataloader object. Uses same config 93 | settings as original training data.
94 | 95 | Args: 96 | train_kwargs (dict): Training dataset config 97 | train_dataloader (Dataloader): Training dataloader 98 | config (Config): General config 99 | robustness_testing_datasets (dict): Modified datasets resulting from robustness tests 100 | 101 | Returns: 102 | sparsity_train (Dataloader) 103 | 104 | """ 105 | sparsity_train = None 106 | if "sparsity" in robustness_testing_datasets: 107 | sparsity_kwargs = { 108 | 'config': config, 109 | 'dataset': robustness_testing_datasets['sparsity'], 110 | 'batch_size': config['train_batch_size'], 111 | 'dl_format': config['MODEL_INPUT_TYPE'], 112 | 'shuffle': True, 113 | } 114 | try: 115 | sparsity_kwargs['sampler'] = train_kwargs['sampler'] 116 | sparsity_kwargs['neg_sample_args'] = train_kwargs['neg_sample_args'] 117 | sparsity_train = train_dataloader(**sparsity_kwargs) 118 | except KeyError:  # no sampler/neg_sample_args in train_kwargs; build without them 119 | sparsity_train = train_dataloader(**sparsity_kwargs) 120 | 121 | return sparsity_train 122 | 123 | 124 | def get_distributional_slice_test(eval_kwargs, test_kwargs, test_dataloader, robustness_testing_datasets): 125 | """ 126 | 127 | Args: 128 | eval_kwargs (dict): Shared evaluation dataloader settings. 129 | test_kwargs (dict): Test dataloader settings (supplies the sampler, if any). 130 | test_dataloader (Dataloader): Evaluation dataloader class. 131 | robustness_testing_datasets (dict): Modified datasets resulting from robustness tests. 132 | 133 | 134 | Returns: 135 | slice_test (Dataloader): Distributional-slice test dataloader, or None. 136 | """ 137 | slice_test = None 138 | if 'distributional_slice' in robustness_testing_datasets: 139 | slice_kwargs = {'dataset': robustness_testing_datasets['distributional_slice']} 140 | if 'sampler' in test_kwargs: 141 | slice_kwargs['sampler'] = test_kwargs['sampler'] 142 | slice_kwargs.update(eval_kwargs) 143 | slice_test = test_dataloader(**slice_kwargs) 144 | 145 | return slice_test 146 | 147 | 148 | def get_slice_test(eval_kwargs, test_kwargs, test_dataloader, robustness_testing_datasets): 149 | """ 150 | 151 | Args: 152 | eval_kwargs (dict): Shared evaluation dataloader settings. 153 | test_kwargs (dict): Test dataloader settings (supplies the sampler, if any). 154 | test_dataloader (Dataloader): Evaluation dataloader class. 155 | robustness_testing_datasets (dict): Modified datasets resulting from robustness tests. 156 | 157 | 158 | Returns: 159 | slice_test (Dataloader): Slice test dataloader, or None. 160 | """ 161 |
slice_test = None 162 | if 'slice' in robustness_testing_datasets: 163 | slice_kwargs = {'dataset': robustness_testing_datasets['slice']} 164 | if 'sampler' in test_kwargs: 165 | slice_kwargs['sampler'] = test_kwargs['sampler'] 166 | slice_kwargs.update(eval_kwargs) 167 | slice_test = test_dataloader(**slice_kwargs) 168 | 169 | return slice_test 170 | 171 | 172 | def get_transformation_test(eval_kwargs, test_kwargs, test_dataloader, robustness_testing_datasets): 173 | """ 174 | 175 | Args: 176 | eval_kwargs (dict): Shared evaluation dataloader settings. 177 | test_kwargs (dict): Test dataloader settings (supplies the sampler, if any). 178 | test_dataloader (Dataloader): Evaluation dataloader class. 179 | robustness_testing_datasets (dict): Modified datasets resulting from robustness tests. 180 | 181 | 182 | Returns: 183 | transformation_test (Dataloader): Transformed-test dataloader, or None. 184 | """ 185 | transformation_test = None 186 | if 'transformation_test' in robustness_testing_datasets: 187 | transformation_kwargs = {'dataset': robustness_testing_datasets['transformation_test']} 188 | if 'sampler' in test_kwargs: 189 | transformation_kwargs['sampler'] = test_kwargs['sampler'] 190 | transformation_kwargs.update(eval_kwargs) 191 | transformation_test = test_dataloader(**transformation_kwargs) 192 | 193 | return transformation_test 194 | 195 | 196 | def data_preparation(config, dataset, save=False): 197 | """ 198 | Builds datasets, including datasets built by applying robustness tests, configures train, validation, test 199 | sets, converts to tensors. Overloads RecBole data_preparation - we include the preparation of the robustness test 200 | train/test/valid sets here.
201 | 202 | Args: 203 | config (Config): General config. 204 | dataset (RobustnessGymDataset): Dataset built by create_dataset(). 205 | save (bool): Whether to save the split dataloaders to disk. 206 | 207 | Returns: 208 | train_data, valid_data, test_data (Dataloader) and robustness_testing_data (dict). 209 | """ 210 | model_type = config['MODEL_TYPE'] 211 | model = config['model'] 212 | es = EvalSetting(config) 213 | 214 | original_datasets, robustness_testing_datasets = dataset.build(es) 215 | train_dataset, valid_dataset, test_dataset = original_datasets 216 | phases = ['train', 'valid', 'test'] 217 | sampler = None 218 | logger = getLogger() 219 | train_neg_sample_args = config['train_neg_sample_args'] 220 | eval_neg_sample_args = es.neg_sample_args 221 | 222 | # Training 223 | train_kwargs = { 224 | 'config': config, 225 | 'dataset': train_dataset, 226 | 'batch_size': config['train_batch_size'], 227 | 'dl_format': config['MODEL_INPUT_TYPE'], 228 | 'shuffle': True, 229 | } 230 | 231 | if train_neg_sample_args['strategy'] != 'none': 232 | if dataset.label_field in dataset.inter_feat: 233 | raise ValueError( 234 | f'`training_neg_sample_num` should be 0 ' 235 | f'if inter_feat has label_field [{dataset.label_field}].'
236 | ) 237 | if model_type != ModelType.SEQUENTIAL: 238 | sampler = Sampler(phases, original_datasets, train_neg_sample_args['distribution']) 239 | else: 240 | sampler = RepeatableSampler(phases, dataset, train_neg_sample_args['distribution']) 241 | if model not in ["MultiVAE", "MultiDAE", "MacridVAE", "CDAE", "ENMF", "RaCT", "RecVAE"]: 242 | train_kwargs['sampler'] = sampler.set_phase('train') 243 | train_kwargs['neg_sample_args'] = train_neg_sample_args 244 | if model_type == ModelType.KNOWLEDGE: 245 | kg_sampler = KGSampler(dataset, train_neg_sample_args['distribution']) 246 | train_kwargs['kg_sampler'] = kg_sampler 247 | 248 | dataloader = get_data_loader('train', config, train_neg_sample_args) 249 | logger.info( 250 | set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' + 251 | set_color('[train]', 'yellow') + ' with format ' + set_color(f'[{train_kwargs["dl_format"]}]', 'yellow') 252 | ) 253 | if train_neg_sample_args['strategy'] != 'none': 254 | logger.info( 255 | set_color('[train]', 'pink') + set_color(' Negative Sampling', 'blue') + f': {train_neg_sample_args}' 256 | ) 257 | else: 258 | logger.info(set_color('[train]', 'pink') + set_color(' No Negative Sampling', 'yellow')) 259 | logger.info( 260 | set_color('[train]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' + 261 | set_color(f'[{train_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' + 262 | set_color(f'[{train_kwargs["shuffle"]}]\n', 'yellow') 263 | ) 264 | 265 | train_data = dataloader(**train_kwargs) 266 | transformed_train = get_transformed_train(config, train_kwargs, dataloader, robustness_testing_datasets) 267 | sparsity_train = get_sparsity_train(config, train_kwargs, dataloader, robustness_testing_datasets) 268 | 269 | # Evaluation 270 | eval_kwargs = { 271 | 'config': config, 272 | 'batch_size': config['eval_batch_size'], 273 | 'dl_format': InputType.POINTWISE, 274 | 'shuffle': False, 275 | } 276 | valid_kwargs = 
{'dataset': valid_dataset} 277 | test_kwargs = {'dataset': test_dataset} 278 | 279 | if eval_neg_sample_args['strategy'] != 'none': 280 | if dataset.label_field in dataset.inter_feat: 281 | raise ValueError( 282 | f'It can not validate with `{es.es_str[1]}` ' 283 | f'when inter_feat have label_field [{dataset.label_field}].' 284 | ) 285 | if sampler is None: 286 | if model_type != ModelType.SEQUENTIAL: 287 | sampler = Sampler(phases, original_datasets, eval_neg_sample_args['distribution']) 288 | else: 289 | sampler = RepeatableSampler(phases, dataset, eval_neg_sample_args['distribution']) 290 | else: 291 | sampler.set_distribution(eval_neg_sample_args['distribution']) 292 | eval_kwargs['neg_sample_args'] = eval_neg_sample_args 293 | valid_kwargs['sampler'] = sampler.set_phase('valid') 294 | test_kwargs['sampler'] = sampler.set_phase('test') 295 | 296 | valid_kwargs.update(eval_kwargs) 297 | test_kwargs.update(eval_kwargs) 298 | 299 | dataloader = get_data_loader('evaluation', config, eval_neg_sample_args) 300 | logger.info( 301 | set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' + 302 | set_color('[evaluation]', 'yellow') + ' with format ' + set_color(f'[{eval_kwargs["dl_format"]}]', 'yellow') 303 | ) 304 | logger.info(es) 305 | logger.info( 306 | set_color('[evaluation]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' + 307 | set_color(f'[{eval_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' + 308 | set_color(f'[{eval_kwargs["shuffle"]}]\n', 'yellow') 309 | ) 310 | 311 | valid_data = dataloader(**valid_kwargs) 312 | test_data = dataloader(**test_kwargs) 313 | 314 | transformed_test = None 315 | if 'transformation_test' in robustness_testing_datasets: 316 | transformed_test_kwargs = test_kwargs 317 | transformed_test_kwargs['dataset'] = robustness_testing_datasets['transformation_test'] 318 | transformed_test = dataloader(**transformed_test_kwargs) 319 | 320 | slice_test = 
get_slice_test(eval_kwargs, test_kwargs, dataloader, robustness_testing_datasets) 321 | distributional_slice_test = get_distributional_slice_test(eval_kwargs, test_kwargs, dataloader, 322 | robustness_testing_datasets) 323 | 324 | if save: 325 | save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data)) 326 | 327 | robustness_testing_data = {'slice': slice_test, 328 | 'distributional_slice': distributional_slice_test, 329 | 'transformation_train': transformed_train, 330 | 'transformation_test': transformed_test, 331 | 'sparsity': sparsity_train} 332 | 333 | return train_data, valid_data, test_data, robustness_testing_data 334 | 335 | 336 | def get_config_dict(robustness_tests, base_config_dict): 337 | """ 338 | Combines robustness_test and train_config_dict into a single config_dict. 339 | 340 | Args: 341 | robustness_tests (dict): robustness test config dict 342 | base_config_dict (dict): train/data/eval/model/hyperparam config dict 343 | 344 | Returns: 345 | config_dict (dict): config dict 346 | """ 347 | config_dict = {} 348 | if robustness_tests is not None: 349 | if base_config_dict is not None: 350 | config_dict = {**robustness_tests, **base_config_dict} 351 | else: 352 | config_dict = robustness_tests 353 | else: 354 | if base_config_dict is not None: 355 | config_dict = base_config_dict 356 | return config_dict 357 | 358 | 359 | def train_and_test(model, dataset, robustness_tests=None, base_config_dict=None, save_model=True): 360 | """ 361 | Train a recommendation model and run robustness tests. 362 | Args: 363 | model (str): Name of model to be trained. 364 | dataset (str): Dataset name; must match the dataset's folder name located in 'data_path' path. 365 | base_config_dict: Configuration dictionary. If no config passed, takes default values. 366 | save_model (bool): Determines whether or not to externally save the model after training. 367 | robustness_tests (dict): Configuration dictionary for robustness tests. 
368 | 369 | Returns: 370 | dict of evaluation results for the original test set and each configured robustness test. 371 | """ 372 | 373 | config_dict = get_config_dict(robustness_tests, base_config_dict) 374 | config = Config(model=model, dataset=dataset, config_dict=config_dict) 375 | init_seed(config['seed'], config['reproducibility']) 376 | 377 | logger = getLogger() 378 | if len(logger.handlers) > 1: 379 | logger.removeHandler(logger.handlers[1]) 380 | init_logger(config) 381 | 382 | logger.info(config) 383 | 384 | # dataset filtering 385 | dataset = create_dataset(config) 386 | logger.info(dataset) 387 | 388 | # dataset splitting 389 | train_data, valid_data, test_data, robustness_testing_data = data_preparation(config, dataset, save=True) 390 | 391 | for robustness_test in robustness_testing_data: 392 | if robustness_testing_data[robustness_test] is not None: 393 | logger.info(set_color('Robustness Test', 'yellow') + f': {robustness_test}') 394 | 395 | # model loading and initialization 396 | model = get_model(config['model'])(config, train_data).to(config['device']) 397 | logger.info(model) 398 | 399 | # trainer loading and initialization 400 | trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) 401 | 402 | # model training 403 | best_valid_score, best_valid_result = trainer.fit( 404 | train_data, valid_data, saved=save_model, show_progress=config['show_progress'] 405 | ) 406 | 407 | # model evaluation 408 | test_result = trainer.evaluate(test_data, load_best_model=save_model, 409 | show_progress=config['show_progress']) 410 | logger.info(set_color('best valid ', 'yellow') + f': {best_valid_result}') 411 | logger.info(set_color('test result', 'yellow') + f': {test_result}') 412 | 413 | test_result_transformation, test_result_sparsity, \ 414 | test_result_slice, test_result_distributional_slice = None, None, None, None 415 | 416 | if robustness_testing_data['slice'] is not None: 417 | test_result_slice = trainer.evaluate(robustness_testing_data['slice'], load_best_model=save_model, 418 |
                                             show_progress=config['show_progress'])
        logger.info(set_color('test result for slice', 'yellow') + f': {test_result_slice}')

    if robustness_testing_data['distributional_slice'] is not None:
        test_result_distributional_slice = trainer.evaluate(robustness_testing_data['distributional_slice'],
                                                            load_best_model=save_model,
                                                            show_progress=config['show_progress'])
        logger.info(set_color('test result for distributional slice', 'yellow') + f': '
                    f'{test_result_distributional_slice}')

    if robustness_testing_data['transformation_test'] is not None:
        test_result_transformation = trainer.evaluate(robustness_testing_data['transformation_test'],
                                                      load_best_model=save_model,
                                                      show_progress=config['show_progress'])
        logger.info(set_color('test result for transformation on test', 'yellow') + f': {test_result_transformation}')

    if robustness_testing_data['transformation_train'] is not None:
        transformation_model = get_model(config['model'])(config, robustness_testing_data['transformation_train']).to(
            config['device'])
        logger.info(transformation_model)
        transformation_trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, transformation_model)
        best_valid_score_transformation, best_valid_result_transformation = transformation_trainer.fit(
            robustness_testing_data['transformation_train'], valid_data, saved=save_model,
            show_progress=config['show_progress'])
        test_result_transformation = transformation_trainer.evaluate(test_data, load_best_model=save_model,
                                                                     show_progress=config['show_progress'])
        logger.info(
            set_color('best valid for transformed training set', 'yellow') + f': {best_valid_result_transformation}')
        logger.info(set_color('test result for transformed training set', 'yellow') + f': {test_result_transformation}')

    if robustness_testing_data['sparsity'] is not None:
        sparsity_model = \
            get_model(config['model'])(config, robustness_testing_data['sparsity']).to(config['device'])
        logger.info(sparsity_model)
        sparsity_trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, sparsity_model)
        best_valid_score_sparsity, best_valid_result_sparsity = sparsity_trainer.fit(
            robustness_testing_data['sparsity'], valid_data, saved=save_model,
            show_progress=config['show_progress'])
        test_result_sparsity = sparsity_trainer.evaluate(test_data, load_best_model=save_model,
                                                         show_progress=config['show_progress'])
        logger.info(set_color('best valid for sparsified training set', 'yellow') + f': {best_valid_result_sparsity}')
        logger.info(set_color('test result for sparsified training set', 'yellow') + f': {test_result_sparsity}')

    logger.handlers.clear()
    shutdown()
    del logger

    return {
        'test_result': test_result,
        'distributional_test_result': test_result_distributional_slice,
        'transformation_test_result': test_result_transformation,
        'sparsity_test_result': test_result_sparsity,
        'slice_test_result': test_result_slice
    }


def test(model, dataset, model_path, dataloader_path=None, robustness_tests=None, base_config_dict=None):
    """
    Test a pre-trained model loaded from a file path. Note that the only robustness test
    applicable here is slicing.

    Args:
        model (str): Name of the model.
        dataset (str): Name of the dataset.
        model_path (str): Path to the saved model.
        dataloader_path (str): Path to saved split dataloaders, relative to SAVED_DIR;
            if None, the dataloaders are rebuilt from the dataset.
        robustness_tests (dict): Configuration dictionary for robustness tests.
        base_config_dict (dict): Configuration dictionary for data/model/training/evaluation.
    Returns:
        dict: Test results for the base test set and, if configured, the slice test.
    """
    config_dict = get_config_dict(robustness_tests, base_config_dict)
    config = Config(model=model, dataset=dataset, config_dict=config_dict)
    init_seed(config['seed'], config['reproducibility'])

    # logger initialization
    logger = getLogger()
    if len(logger.handlers) != 0:
        logger.removeHandler(logger.handlers[1])
    init_logger(config)

    # dataset filtering
    dataset = create_dataset(config)
    logger.info(dataset)

    # dataset splitting
    if dataloader_path is None:
        train_data, _, test_data, robustness_testing_data = data_preparation(config, dataset, save=False)
    else:
        with open(SAVED_DIR + dataloader_path, "rb") as f:
            train_data, valid_data, test_data = pickle.load(f)
        robustness_testing_data = {"slice": None, "transformation": None, "sparsity": None}

    # model loading and initialization
    model = get_model(config['model'])(config, train_data).to(config['device'])
    logger.info(model)

    # trainer loading and initialization
    trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

    # model evaluation
    test_result = trainer.evaluate(test_data, load_best_model=True, model_file=model_path,
                                   show_progress=config['show_progress'])
    logger.info(set_color('test result', 'yellow') + f': {test_result}')

    test_result_slice = None
    if robustness_testing_data['slice'] is not None:
        test_result_slice = trainer.evaluate(robustness_testing_data['slice'], load_best_model=True,
                                             model_file=model_path,
                                             show_progress=config['show_progress'])
        logger.info(set_color('test result for slice', 'yellow') + f': {test_result_slice}')

    return {
        'test_result': test_result,
        'slice_test_result': test_result_slice
    }


if __name__ == '__main__':
    all_results = {}
    for model in ["BPR"]:
        dataset = "ml-100k"
        base_config_dict = {
            'data_path': DATASETS_DIR,
            'show_progress': False,
            'save_dataset': True,
            'load_col': {'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
                         'user': ['user_id', 'age', 'gender', 'occupation'],
                         'item': ['item_id', 'release_year', 'class']}
        }
        # Add robustness test specifications here; with None, only standard
        # training and evaluation are run.
        robustness_dict = None
        results = train_and_test(model=model, dataset=dataset, robustness_tests=robustness_dict,
                                 base_config_dict=base_config_dict)
        all_results[model] = results
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
## Security

Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
as soon as it is discovered. This library limits its runtime dependencies in
order to reduce the total cost of ownership as much as possible, but all consumers
should remain vigilant and have their security stakeholders review all third-party
products (3PP) like this one and their dependencies.
--------------------------------------------------------------------------------
/utils/GlobalVars.py:
--------------------------------------------------------------------------------
"""
/*
 * Copyright (c) 2021, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 */
"""
LOG_DIR = "./log"
RESULTS_DIR = "./results"
SAVED_DIR = "./saved"
DATASETS_DIR = ".datasets/"

GENERAL_MODELS = ["Pop", "ItemKNN", "BPR", "NeuMF", "ConvNCF", "DMF", "FISM", "NAIS", "SpectralCF", "GCMC",
                  "NGCF", "LightGCN", "DGCF", "LINE", "MultiVAE", "MultiDAE", "MacridVAE", "CDAE", "ENMF",
                  "NNCF", "RaCT", "RecVAE", "EASE", "SLIMElastic"]

CONTEXT_MODELS = ["LR", "FM", "NFM", "DeepFM", "xDeepFM", "AFM", "FFM", "FwFM", "FNN", "PNN", "DSSM", "WideDeep",
                  "DCN", "AutoInt"]

KNOWLEDGE_MODELS = ["CKE", "CFKG", "KTUP", "KGAT", "RippleNet", "MKR", "KGCN", "KGNNLS"]

SEQUENTIAL_MODELS = ["FPMC", "GRU4REC", "NARM", "STAMP", "Caser", "NextItNet", "TransRec", "SASRec", "BERT4Rec",
                     "SRGNN", "GCSAN", "GRU4RecF", "SASRecF", "FDSA", "S3Rec", "GRU4RecKG", "KSR", "FOSSIL",
                     "SHAN", "RepeatNet", "HGN", "HRM", "NPE"]
--------------------------------------------------------------------------------
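
A note on the merge semantics of `get_config_dict` in RobustnessGymRecsys.py: because `base_config_dict` is unpacked last in `{**robustness_tests, **base_config_dict}`, its values win on key collisions. The sketch below is a standalone copy of that function with hypothetical config keys, purely for illustration.

```python
def get_config_dict(robustness_tests, base_config_dict):
    # Standalone copy of the merge helper from RobustnessGymRecsys.py.
    config_dict = {}
    if robustness_tests is not None:
        if base_config_dict is not None:
            # base_config_dict is unpacked second, so it overrides on collisions
            config_dict = {**robustness_tests, **base_config_dict}
        else:
            config_dict = robustness_tests
    else:
        if base_config_dict is not None:
            config_dict = base_config_dict
    return config_dict


# Hypothetical keys, for illustration only.
robustness = {'seed': 1, 'slice': 'user_age'}
base = {'seed': 42, 'epochs': 10}

merged = get_config_dict(robustness, base)
print(merged)  # {'seed': 42, 'slice': 'user_age', 'epochs': 10}

# If one argument is None, the other dict is returned as-is (not copied).
print(get_config_dict(None, base) is base)  # True
```

One consequence of this ordering is that robustness-test settings can never silently override training settings: whatever is in `base_config_dict` always takes precedence.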