├── .gitignore
├── README.md
├── assets
│   ├── coral.png
│   └── slide.pdf
├── data
│   └── .gitkeep
├── requirements.txt
├── run_inspired.sh
├── run_pearl.sh
├── run_redial.sh
├── src
│   ├── __init__.py
│   ├── dataset.py
│   ├── metric.py
│   ├── model.py
│   └── trainer.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
outputs/
src/config/private.yaml
src/notebook
**/__pycache__/
*.log
Dockerfile
docker_run.sh
data/*
!data/.gitkeep
wandb/
checkpoints/

# Created by https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm,windows,linux,powershell,data,jupyternotebooks
# Edit at https://www.toptal.com/developers/gitignore?templates=python,visualstudiocode,pycharm,windows,linux,powershell,data,jupyternotebooks

### Data ###
*.csv
*.dat
*.efx
*.gbr
*.key
*.pps
*.ppt
*.pptx
*.sdf
*.tax2010
*.vcf
*.xml

### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
# git rm -r .ipynb_checkpoints/

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### PowerShell ###
# Exclude packaged modules
*.zip

# Exclude .NET assemblies from source
*.dll

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

# Azure Toolkit for IntelliJ plugin
# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij
.idea/**/azureSettings.xml

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook

# IPython

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### macOS Patch ###
# iCloud generated files
*.icloud

# End of https://www.toptal.com/developers/gitignore/api/python,visualstudiocode,pycharm,windows,linux,powershell,data,jupyternotebooks

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🪸 Empowering Retrieval-based Conversational Recommendation with Contrasting User Preferences (NAACL'25)

This is the official repository for the NAACL 2025 paper:
[Empowering Retrieval-based Conversational Recommendation with Contrasting User Preferences](https://arxiv.org/abs/2503.22005)

## 🧠 Overview

🪸CORAL is a retrieval-based conversational recommender system (CRS) framework that explicitly represents and models users, items, and contrasting user preferences.

![Overview of CORAL](assets/coral.png)

For more details, please refer to our [📄paper](https://arxiv.org/abs/2503.22005), [🛝slide](assets/slide.pdf), and [🌐blog(Korean)](https://dial.skku.edu/blog/2025_coral).

## 📦 Installation

```
git clone https://github.com/kookeej/CORAL.git
cd CORAL
conda create -n coral python=3.10
conda activate coral
pip install -r requirements.txt
```

## 🗃️ Dataset
The datasets used in this work are available on [Hugging Face](https://huggingface.co/datasets/kookeej/CORAL).
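
If you prefer to fetch them programmatically, here is a minimal sketch using `huggingface_hub` (the `repo_id` is taken from the link above; adjust `local_dir` if your layout differs):

```
from huggingface_hub import snapshot_download

# Download the CORAL datasets into the local data/ folder.
snapshot_download(repo_id="kookeej/CORAL", repo_type="dataset", local_dir="data")
```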

Please download and extract the datasets into the `data/` folder as follows:

```
data/
├── pearl/
├── inspired/
└── redial/
```

## 🚀 Training
Run the training scripts for each dataset:

```
# Pearl
bash run_pearl.sh

# Inspired
bash run_inspired.sh

# Redial
bash run_redial.sh
```

## 🧾 Citation
If you find this work helpful, please cite our paper:
```
@inproceedings{kook-etal-2025-empowering,
    title = "Empowering Retrieval-based Conversational Recommendation with Contrasting User Preferences",
    author = "Kook, Heejin and
      Kim, Junyoung and
      Park, Seongmin and
      Lee, Jongwuk",
    booktitle = "Proceedings of the 2025 Conference of the Nations of the Americas Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)",
    month = apr,
    year = "2025",
    address = "Albuquerque, New Mexico",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2025.naacl-long.392/",
    pages = "7692--7707",
    ISBN = "979-8-89176-189-6",
}
```

--------------------------------------------------------------------------------
/assets/coral.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kookeej/CORAL/3002077452725c5f4ea63af70bdd98be03f93529/assets/coral.png

--------------------------------------------------------------------------------
/assets/slide.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kookeej/CORAL/3002077452725c5f4ea63af70bdd98be03f93529/assets/slide.pdf

--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kookeej/CORAL/3002077452725c5f4ea63af70bdd98be03f93529/data/.gitkeep

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
transformers~=4.43.3
accelerate~=1.6.0
wandb
wonderwords
peft
datasets
einops

--------------------------------------------------------------------------------
/run_inspired.sh:
--------------------------------------------------------------------------------
python train.py \
    --data_name inspired \
    --alpha 0.5 \
    --beta 0.2 \
    --query_used_info c l d \
    --doc_used_info m pref \
    --bf16 \
    --negative_sample 16 \
    --learning_rate 1e-4 \
    --batch_size 10 \
    --wandb_project CORAL  # (Optional) Remove this if you are not using wandb.

--------------------------------------------------------------------------------
/run_pearl.sh:
--------------------------------------------------------------------------------
python train.py \
    --data_name pearl \
    --alpha 0.5 \
    --beta 0.3 \
    --query_used_info c l d \
    --doc_used_info m pref \
    --bf16 \
    --negative_sample 24 \
    --learning_rate 5e-5 \
    --batch_size 8 \
    --patience 3 \
    --wandb_project CORAL  # (Optional) Remove this if you are not using wandb.

--------------------------------------------------------------------------------
/run_redial.sh:
--------------------------------------------------------------------------------
python train.py \
    --data_name redial \
    --alpha 0.5 \
    --beta 0.1 \
    --query_used_info c l d \
    --doc_used_info m pref \
    --bf16 \
    --negative_sample 16 \
    --learning_rate 5e-5 \
    --batch_size 10 \
    --wandb_project CORAL  # (Optional) Remove this if you are not using wandb.

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
from .trainer import CoralTrainer
from .dataset import CoralDataset, CoralItemDataset
from .model import CoralEncoder

--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from typing import List, Optional


class CoralItemDataset(Dataset):
    def __init__(
            self,
            item_text: dict[str, List[str]],
            base_tokenizer_name: Optional[str] = 'nvidia/NV-Embed-v1',
            max_length: Optional[int] = 512,
            truncation_side: Optional[str] = 'right',
    ):
        super().__init__()
        self.item_texts = item_text['item'] if 'item' in item_text else None

        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name, truncation_side=truncation_side)

    def __len__(self):
        return len(self.item_texts)

    def __getitem__(self, idx):
        tokenized_i = self.tokenizer(
            self.item_texts[idx],
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
            add_special_tokens=True
        )

        new_tokenized_i = dict()
        if tokenized_i is not None:
            for key, value in tokenized_i.items():
                new_tokenized_i[f'item_{key}'] = value

        new_tokenized_i = {key: val.squeeze(0) for key, val in new_tokenized_i.items() if val is not None}
        return new_tokenized_i

    def collate_fn(self, batch):
        item_input_ids, item_attention_mask = [], []

        for item in batch:
            item_input_ids.append(item['item_input_ids'])
            item_attention_mask.append(item['item_attention_mask'])

        # pad into max length
        item_input_ids = torch.nn.utils.rnn.pad_sequence(item_input_ids, batch_first=True,
                                                         padding_value=self.tokenizer.pad_token_id)
        item_attention_mask = torch.nn.utils.rnn.pad_sequence(item_attention_mask, batch_first=True,
                                                              padding_value=0)

        return {
            'item_input_ids': item_input_ids if self.item_texts else None,
            'item_attention_mask': item_attention_mask if self.item_texts else None,
        }
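
# Usage sketch: CoralItemDataset holds one formatted text per catalog item; the
# trainer encodes all items once per epoch (see CoralTrainer._encode_items) to
# build the item-embedding matrix used for negative sampling and ranking.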


class CoralDataset(Dataset):
    def __init__(
            self,
            user_text: dict[str, List[str]],
            gt_ids: List[int] | List[List[int]],
            base_tokenizer_name: Optional[str] = 'nvidia/NV-Embed-v1',
            max_length: Optional[int] = 512,
            truncation_side: Optional[str] = 'right',
    ):
        super().__init__()
        self.conv_texts = user_text['c'] if 'c' in user_text else None
        self.like_texts = user_text['l'] if 'l' in user_text else None
        self.dislike_texts = user_text['d'] if 'd' in user_text else None

        self.gt_ids = gt_ids
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(base_tokenizer_name, truncation_side=truncation_side)

        self.len = len(self.gt_ids)  # one example per query, even if some text fields are unused

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        tokenized_conv = self.tokenizer(
            self.conv_texts[idx],
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
            add_special_tokens=True
        ) if self.conv_texts is not None else None

        tokenized_like = self.tokenizer(
            self.like_texts[idx],
            max_length=self.max_length // 2,
            truncation=True,
            return_tensors='pt',
            add_special_tokens=True
        ) if self.like_texts is not None else None

        tokenized_dislike = self.tokenizer(
            self.dislike_texts[idx],
            max_length=self.max_length // 2,
            truncation=True,
            return_tensors='pt',
            add_special_tokens=True
        ) if self.dislike_texts is not None else None

        tokenized_u = dict()
        if tokenized_conv is not None:
            for key, value in tokenized_conv.items():
                tokenized_u[f'conv_{key}'] = value

        if tokenized_like is not None:
            for key, value in tokenized_like.items():
                tokenized_u[f'like_{key}'] = value

        if tokenized_dislike is not None:
            for key, value in tokenized_dislike.items():
                tokenized_u[f'dislike_{key}'] = value

        tokenized_u = {key: val.squeeze(0) for key, val in tokenized_u.items() if val is not None}
        gt_id = self.gt_ids[idx]

        return idx, gt_id, tokenized_u
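
    # Collation sketch: sequences in a batch have different lengths, so each field
    # is right-padded to the batch maximum (input ids with the tokenizer's pad id,
    # attention masks with 0) before stacking into a single tensor.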
    def collate_fn(self, batch):
        batch_user_id = [idx for idx, _, _ in batch]
        batch_gt_id = [idx for _, idx, _ in batch]
        if self.conv_texts:
            conv_input_ids, conv_attention_mask = [], []
        if self.like_texts:
            like_input_ids, like_attention_mask = [], []
        if self.dislike_texts:
            dislike_input_ids, dislike_attention_mask = [], []

        for _, _, tokenized_user in batch:
            if self.conv_texts:
                conv_input_ids.append(tokenized_user['conv_input_ids'])
                conv_attention_mask.append(tokenized_user['conv_attention_mask'])
            if self.like_texts:
                like_input_ids.append(tokenized_user['like_input_ids'])
                like_attention_mask.append(tokenized_user['like_attention_mask'])
            if self.dislike_texts:
                dislike_input_ids.append(tokenized_user['dislike_input_ids'])
                dislike_attention_mask.append(tokenized_user['dislike_attention_mask'])

        # pad into max length
        if self.conv_texts:
            conv_input_ids = torch.nn.utils.rnn.pad_sequence(conv_input_ids, batch_first=True,
                                                             padding_value=self.tokenizer.pad_token_id)
            conv_attention_mask = torch.nn.utils.rnn.pad_sequence(conv_attention_mask, batch_first=True,
                                                                  padding_value=0)
        if self.like_texts:
            like_input_ids = torch.nn.utils.rnn.pad_sequence(like_input_ids, batch_first=True,
                                                             padding_value=self.tokenizer.pad_token_id)
            like_attention_mask = torch.nn.utils.rnn.pad_sequence(like_attention_mask, batch_first=True,
                                                                  padding_value=0)
        if self.dislike_texts:
            dislike_input_ids = torch.nn.utils.rnn.pad_sequence(dislike_input_ids, batch_first=True,
                                                                padding_value=self.tokenizer.pad_token_id)
            dislike_attention_mask = torch.nn.utils.rnn.pad_sequence(dislike_attention_mask, batch_first=True,
                                                                     padding_value=0)
        return batch_user_id, batch_gt_id, {
            'conv_input_ids': conv_input_ids if self.conv_texts else None,
            'conv_attention_mask': conv_attention_mask if self.conv_texts else None,
            'like_input_ids': like_input_ids if self.like_texts else None,
            'like_attention_mask': like_attention_mask if self.like_texts else None,
            'dislike_input_ids': dislike_input_ids if self.dislike_texts else None,
            'dislike_attention_mask': dislike_attention_mask if self.dislike_texts else None
        }

--------------------------------------------------------------------------------
/src/metric.py:
--------------------------------------------------------------------------------
import numpy as np


def recall_at_k(true_labels, predicted_labels, k):
    """
    Compute Recall@k

    Args:
    - true_labels (list): A list where each element is a list of relevant document indices for a query.
    - predicted_labels (list): A list where each element is a list of predicted document indices for a query.
    - k (int): The number of top documents to consider for the metric.

    Returns:
    - recall (float): The average Recall@k score over all queries.
    """
    recalls = []

    for true, pred in zip(true_labels, predicted_labels):
        if isinstance(true, (str, int)):
            true = [true]

        pred_k = pred[:k]
        if not true:
            recalls.append(0.0)
        else:
            recall = len(set(pred_k) & set(true)) / len(true)
            recalls.append(recall)
    return np.mean(recalls)
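
# Worked example (illustrative numbers): with true_labels=[[3, 7]] and
# predicted_labels=[[7, 1, 2]], recall_at_k(..., k=2) inspects pred[:2] = [7, 1],
# finds one hit out of two relevant items, and returns 0.5.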
43 | """ 44 | 45 | def dcg(rel_scores): 46 | rel_scores = np.array(rel_scores) 47 | discounts = np.log2(np.arange(len(rel_scores)) + 2) 48 | return np.sum((2 ** rel_scores - 1) / discounts) 49 | 50 | def idcg(n_relevant): 51 | # The ideal DCG is obtained by taking the highest possible relevance scores 52 | ideal_rel = [1] * n_relevant + [0] * (k - n_relevant) 53 | return dcg(ideal_rel) 54 | 55 | ndcgs = [] 56 | for true, pred in zip(true_labels, predicted_labels): 57 | if type(true) == str: 58 | true = [true] 59 | if type(true) == int: 60 | true = [true] 61 | 62 | rel_scores = [(1 if p in true else 0) for p in pred[:k]] 63 | actual_idcg = idcg(min(len(true), k)) 64 | actual_dcg = dcg(rel_scores) 65 | if actual_idcg == 0: 66 | ndcgs.append(0.0) 67 | else: 68 | ndcgs.append(actual_dcg / actual_idcg) 69 | return np.mean(ndcgs) 70 | 71 | 72 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | from mpmath import convert 5 | from peft import LoraConfig, LoraModel, get_peft_model, PeftModel 6 | 7 | from tqdm import tqdm 8 | from typing import Optional, List, Tuple, Union 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from transformers import ( 15 | AutoTokenizer, 16 | AutoModel 17 | ) 18 | 19 | 20 | def _mean_pooling(output, attention_mask): 21 | input_mask_expanded = attention_mask.unsqueeze(-1).expand_as(output).float() # same with sentence_embeddings 22 | sum_output = torch.sum(output * input_mask_expanded, 1) 23 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 24 | mean_output = sum_output / sum_mask 25 | 26 | return mean_output 27 | 28 | 29 | class CoralEncoder(nn.Module): 30 | def __init__(self, base_model_name: str, alpha:float, beta:float): 31 | super().__init__() 32 | self.model = AutoModel.from_pretrained(base_model_name, trust_remote_code=True, torch_dtype=torch.bfloat16) 33 | self.alpha, self.beta = alpha, beta 34 | 35 | self.configure_lora() 36 | 37 | def configure_lora(self): 38 | for name, param in self.model.named_parameters(): 39 | param.requires_grad = False 40 | target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] 41 | config = LoraConfig( 42 | target_modules=target_modules, 43 | r=16, 44 | lora_alpha=32, 45 | lora_dropout=0.1 46 | ) 47 | self.model = get_peft_model(self.model, config) 48 | self.model.print_trainable_parameters() 49 | 50 | 51 | def forward( 52 | self, 53 | item_input_ids: Optional[torch.Tensor] = None, 54 | item_attention_mask: Optional[torch.Tensor] = None, 55 | conv_input_ids: Optional[torch.Tensor] = None, 56 | conv_attention_mask: Optional[torch.Tensor] = None, 57 | like_input_ids: Optional[torch.Tensor] = None, 58 | like_attention_mask: Optional[torch.Tensor] = None, 59 | dislike_input_ids: Optional[torch.Tensor] = None, 60 | dislike_attention_mask: Optional[torch.Tensor] = None, 61 | ): 62 | 63 | item_output = self.model( 64 | input_ids=item_input_ids, 65 | attention_mask=item_attention_mask, 66 | ) if item_input_ids is not None else None 67 | 68 | conv_output = self.model( 69 | input_ids=conv_input_ids, 70 | attention_mask=conv_attention_mask, 71 | ) if conv_input_ids is not None else None 72 | 73 | like_output = self.model( 74 | input_ids=like_input_ids, 75 | attention_mask=like_attention_mask, 76 | ) if like_input_ids is not None else None 77 | 78 | dislike_output = self.model( 79 


class CoralEncoder(nn.Module):
    def __init__(self, base_model_name: str, alpha: float, beta: float):
        super().__init__()
        self.model = AutoModel.from_pretrained(base_model_name, trust_remote_code=True, torch_dtype=torch.bfloat16)
        self.alpha, self.beta = alpha, beta

        self.configure_lora()

    def configure_lora(self):
        # freeze the backbone and train only the LoRA adapters on the attention projections
        for name, param in self.model.named_parameters():
            param.requires_grad = False
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
        config = LoraConfig(
            target_modules=target_modules,
            r=16,
            lora_alpha=32,
            lora_dropout=0.1
        )
        self.model = get_peft_model(self.model, config)
        self.model.print_trainable_parameters()

    def forward(
            self,
            item_input_ids: Optional[torch.Tensor] = None,
            item_attention_mask: Optional[torch.Tensor] = None,
            conv_input_ids: Optional[torch.Tensor] = None,
            conv_attention_mask: Optional[torch.Tensor] = None,
            like_input_ids: Optional[torch.Tensor] = None,
            like_attention_mask: Optional[torch.Tensor] = None,
            dislike_input_ids: Optional[torch.Tensor] = None,
            dislike_attention_mask: Optional[torch.Tensor] = None,
    ):

        item_output = self.model(
            input_ids=item_input_ids,
            attention_mask=item_attention_mask,
        ) if item_input_ids is not None else None

        conv_output = self.model(
            input_ids=conv_input_ids,
            attention_mask=conv_attention_mask,
        ) if conv_input_ids is not None else None

        like_output = self.model(
            input_ids=like_input_ids,
            attention_mask=like_attention_mask,
        ) if like_input_ids is not None else None

        dislike_output = self.model(
            input_ids=dislike_input_ids,
            attention_mask=dislike_attention_mask,
        ) if dislike_input_ids is not None else None

        item_embedding = _mean_pooling(item_output['sentence_embeddings'],
                                       attention_mask=item_attention_mask) if item_output is not None else None
        conv_embedding = _mean_pooling(conv_output['sentence_embeddings'],
                                       attention_mask=conv_attention_mask) if conv_output is not None else None
        like_embedding = _mean_pooling(like_output['sentence_embeddings'],
                                       attention_mask=like_attention_mask) if like_output is not None else None
        dislike_embedding = _mean_pooling(dislike_output['sentence_embeddings'],
                                          attention_mask=dislike_attention_mask) if dislike_output is not None else None

        item_embedding = F.normalize(item_embedding, p=2, dim=-1) if item_embedding is not None else None
        conv_embedding = F.normalize(conv_embedding, p=2, dim=-1) if conv_embedding is not None else None
        like_embedding = F.normalize(like_embedding, p=2, dim=-1) if like_embedding is not None else None
        dislike_embedding = F.normalize(dislike_embedding, p=2, dim=-1) if dislike_embedding is not None else None

        return {
            'item_embedding': item_embedding,
            'conv_embedding': conv_embedding,
            'like_embedding': like_embedding,
            'dislike_embedding': dislike_embedding,
        }

    def save_checkpoint(self, ckpt_save_path):
        os.makedirs(ckpt_save_path, exist_ok=True)
        self.model.save_pretrained(Path(ckpt_save_path))

    def load_best_checkpoint(self, ckpt_save_path):
        self.model.load_adapter(Path(ckpt_save_path), adapter_name='lora')

--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
from typing import Optional, Literal, List, Dict

import numpy as np
import torch
import torch.optim as optim
import transformers
import wandb
from torch import nn
from torch.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm

from .metric import recall_at_k, ndcg_at_k

torch.set_float32_matmul_precision("medium")


class CoralTrainer:
    def __init__(
            self,
            model,
            train_dataloader: DataLoader,
            valid_dataloader: DataLoader,
            test_dataloader: DataLoader,
            train_infer_dataloader: DataLoader,
            item_dataloader: DataLoader,
            item_num: int,
            train_conv_gt_ids: List[List],
            train_gt_ids: List,
            valid_gt_ids: List,
            test_gt_ids: List,
            optimizer: Literal['Adam', 'AdamW'],
            scheduler: Literal['get_linear_schedule_with_warmup', 'get_constant_schedule_with_warmup',
                               'get_cosine_schedule_with_warmup'],
            accumulation_steps: int,
            training_args: Dict,
            wandb_logger: wandb.sdk.wandb_run.Run | None = None
    ):

        self.model = model
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.test_dataloader = test_dataloader
        self.train_infer_dataloader = train_infer_dataloader
        self.item_dataloader = item_dataloader

        self.train_conv_gt_ids = train_conv_gt_ids
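        # The boolean mask below marks every ground-truth item of each training
        # conversation so that negative sampling can exclude known positives.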
        self.train_conv_gt_mask = torch.zeros(len(train_conv_gt_ids), item_num, dtype=torch.bool)
        for idx, doc_ids in enumerate(train_conv_gt_ids):
            self.train_conv_gt_mask[idx, doc_ids] = True

        self.train_gt_ids = train_gt_ids
        self.valid_gt_ids = valid_gt_ids
        self.test_gt_ids = test_gt_ids

        self.optimizer = getattr(optim, optimizer)
        self.scheduler = getattr(transformers, scheduler)
        self.accumulation_steps = accumulation_steps

        self.training_args = training_args
        self.wandb_logger = wandb_logger

        self.loss_fn = nn.CrossEntropyLoss()

        self.item_embeddings = None
        self.negative_samples = None

    def _set_optimizer(self, learning_rate: float):
        return self.optimizer(self.model.parameters(), lr=learning_rate)

    def _set_scheduler(self, optimizer, num_warmup_steps: Optional[int] = None,
                       num_training_steps: Optional[int] = None):
        if num_warmup_steps is None:
            num_warmup_steps = int(len(self.train_dataloader) * 0.1)
        if num_training_steps is None:
            num_training_steps = len(self.train_dataloader) * self.training_args['epochs']
        # pass by keyword: transformers warmup schedulers expect (optimizer, num_warmup_steps, num_training_steps)
        return self.scheduler(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    @torch.no_grad()
    def _encode_items(self, item_dataloader, device):
        self.model.eval()
        self.model.to(device)
        with autocast(dtype=torch.bfloat16, enabled=self.training_args['bf16'], device_type=device):
            item_encoded = []
            for batch in tqdm(item_dataloader, desc="Encoding Items..."):
                inputs = {key: val.squeeze(1).to(device) if val is not None else val for key, val in batch.items()}
                outputs = self.model(**inputs)
                item_encoded.extend(outputs['item_embedding'].cpu())
        return torch.stack(item_encoded, dim=0)  # (item_size, hidden_size)

    @torch.no_grad()
    def _encode_conv(self, train_infer_dataloader, device):
        self.model.eval()
        self.model.to(device)
        conv_encoded = []
        with autocast(dtype=torch.bfloat16, enabled=self.training_args['bf16'], device_type=device):
            for batch in tqdm(train_infer_dataloader, desc="Encoding Conv..."):
                _, _, batch = batch
                inputs = {key: val.squeeze(1).to(device) if val is not None else val for key, val in batch.items()}
                outputs = self.model(**inputs)
                conv_encoded.extend(outputs["conv_embedding"].cpu())
        return torch.stack(conv_encoded, dim=0)  # (train_size, hidden_size)

    @torch.no_grad()
    def _sample_negatives(self, negative_num, device):
        conv_embedding = self._encode_conv(self.train_infer_dataloader, device=device)
        score = torch.matmul(conv_embedding, self.item_embeddings.T)

        # mask multiple positive items
        score.masked_fill_(self.train_conv_gt_mask, float('-inf'))
        score = torch.softmax(score, dim=-1)
        negative_list = torch.multinomial(score, num_samples=negative_num, replacement=False)
        return negative_list  # (query_size, negative_num)
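
    # Scoring sketch (mirrors the loss and prediction below): the relevance of item i
    # to user u is s(u, i) = conv_u · i + alpha * like_u · i - beta * dislike_u · i,
    # scaled by 1 / temp for the cross-entropy loss, with the positive item at index 0
    # of each candidate list.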
    def _ranking_loss(self, output_u, batch_item, batch_label, device):
        rep_u_conv = output_u["conv_embedding"]  # (batch_size, hidden_size)

        total_score = torch.zeros(rep_u_conv.size(0), batch_item.size(1)).to(device)
        if 'c' in self.training_args['query_used_info']:
            total_score += torch.bmm(rep_u_conv.unsqueeze(1), batch_item.permute(0, 2, 1)).squeeze(1)
        if 'l' in self.training_args['query_used_info']:
            rep_u_like = output_u["like_embedding"]  # (batch_size, hidden_size)
            total_score += self.model.alpha * torch.bmm(rep_u_like.unsqueeze(1), batch_item.permute(0, 2, 1)).squeeze(1)
        if 'd' in self.training_args['query_used_info']:
            rep_u_dislike = output_u["dislike_embedding"]
            total_score -= self.model.beta * torch.bmm(rep_u_dislike.unsqueeze(1), batch_item.permute(0, 2, 1)).squeeze(1)
        total_score = total_score / self.training_args['temp']
        loss = self.loss_fn(total_score, batch_label)
        return loss

    def _ranking_predict(self, output_u, device):
        rep_u_conv = output_u["conv_embedding"]  # (batch_size, hidden_size)
        similarity_scores = torch.zeros(rep_u_conv.size(0), self.item_embeddings.size(0)).to(device)
        item_embeddings = self.item_embeddings.to(device)
        if 'c' in self.training_args['query_used_info']:
            similarity_scores += (rep_u_conv @ item_embeddings.T)
        if 'l' in self.training_args['query_used_info']:
            rep_u_like = output_u["like_embedding"]
            similarity_scores += self.model.alpha * (rep_u_like @ item_embeddings.T)
        if 'd' in self.training_args['query_used_info']:
            rep_u_dislike = output_u["dislike_embedding"]
            similarity_scores -= self.model.beta * (rep_u_dislike @ item_embeddings.T)
        return similarity_scores

    def train_one_epoch(self, epoch, optimizer, scheduler, device):
        print("\n>> Epoch: ", epoch + 1)
        self.model.to(device)
        self.model.train()

        self.item_embeddings = self._encode_items(self.item_dataloader, device=device).to('cpu')
        self.negative_samples = self._sample_negatives(self.training_args['negative_sample'] - 1, device)  # (query_size, negative_num)

        total_train_loss = 0
        tq = tqdm(self.train_dataloader, desc=f"Training Epoch {epoch + 1}")
        for step, batch in enumerate(tq):
            batch_u_id, batch_gt_id, input_u = batch
            input_u = {key: val.to(device) if val is not None else val for key, val in input_u.items()}

            batch_negative_samples = self.negative_samples[batch_u_id]  # (batch_size, negative_num-1)
            batch_negative = torch.stack([self.item_embeddings[idx] for idx in batch_negative_samples], dim=0).to(
                device)  # (batch_size, negative_num-1, hidden_size)
            batch_positive = torch.stack([self.item_embeddings[idx] for idx in batch_gt_id], dim=0).to(
                device).unsqueeze(1)  # (batch_size, 1, hidden_size)
            batch_item = torch.cat([batch_positive, batch_negative], dim=1)  # (batch_size, negative_num, hidden_size)
            batch_label = torch.zeros(len(batch_u_id), dtype=torch.long).to(device)

            with autocast(dtype=torch.bfloat16, enabled=self.training_args['bf16'], device_type=device):
                output_u = self.model(**input_u)
                train_loss = self._ranking_loss(output_u, batch_item, batch_label, device)

            loss_item = train_loss.item()
            if self.wandb_logger is not None:
                self.wandb_logger.log({"train_step_loss": loss_item})

            train_loss = train_loss / self.accumulation_steps
            train_loss.backward()

            if (step + 1) % self.accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            total_train_loss += loss_item
            tq.set_postfix(loss=loss_item)
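
    # Validation sketch: the loss is computed over the same positive-plus-sampled-
    # negatives layout as training, while Recall@k / NDCG@k are computed by ranking
    # the full item catalog with _ranking_predict.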
    @torch.no_grad()
    def valid_one_epoch(self, epoch, device):
        self.model.to(device)
        self.model.eval()

        total_valid_loss = 0
        total_top_k_document = []
        for step, batch in enumerate(tqdm(self.valid_dataloader, desc=f"Validation Epoch {epoch + 1}")):
            batch_u_id, batch_gt_id, input_u = batch
            input_u = {key: val.to(device) if val is not None else val for key, val in input_u.items()}

            batch_negative_samples = self.negative_samples[batch_u_id]  # (batch_size, negative_num-1)
            batch_negative = torch.stack([self.item_embeddings[idx] for idx in batch_negative_samples], dim=0).to(
                device)  # (batch_size, negative_num-1, hidden_size)
            batch_positive = torch.stack([self.item_embeddings[idx] for idx in batch_gt_id], dim=0).to(
                device).unsqueeze(1)  # (batch_size, 1, hidden_size)
            batch_item = torch.cat([batch_positive, batch_negative], dim=1)  # (batch_size, negative_num, hidden_size)
            batch_label = torch.zeros(len(batch_u_id), dtype=torch.long).to(device)

            with autocast(dtype=torch.bfloat16, enabled=self.training_args['bf16'], device_type=device):
                output_u = self.model(**input_u)
                valid_loss = self._ranking_loss(output_u, batch_item, batch_label, device)
                similarity_scores = self._ranking_predict(output_u, device)

            batch_top_k_indices = similarity_scores.topk(k=max(self.training_args['cutoff']), dim=-1).indices.tolist()

            for top_k_indices in batch_top_k_indices:
                total_top_k_document.append(top_k_indices)

            total_valid_loss += valid_loss.item()

            if self.wandb_logger is not None:
                self.wandb_logger.log({"valid_step_loss": valid_loss.item()})

        print(">> Validation results:")
        performance = {}
        for k in self.training_args['cutoff']:
            recall = recall_at_k(self.valid_gt_ids, total_top_k_document, k)
            ndcg = ndcg_at_k(self.valid_gt_ids, total_top_k_document, k)
            performance[f'valid/Recall@{k}'] = recall
            performance[f'valid/NDCG@{k}'] = ndcg

            print(f">>> Recall@{k}: {recall:.4f}, NDCG@{k}: {ndcg:.4f}")

        if self.wandb_logger is not None:
            self.wandb_logger.log(performance)

        return ndcg_at_k(self.valid_gt_ids, total_top_k_document, 10)

    @torch.no_grad()
    def inference(self, device, zero_shot):
        self.model.to(device)
        self.model.eval()

        total_top_k_document = []
        for _, batch in enumerate(tqdm(self.test_dataloader, desc="Testing")):
            _, _, input_u = batch
            input_u = {key: val.to(device) if val is not None else val for key, val in input_u.items()}
            with autocast(dtype=torch.bfloat16, enabled=self.training_args['bf16'], device_type=device):
                output_u = self.model(**input_u)
                similarity_scores = self._ranking_predict(output_u, device)

            batch_top_k_indices = similarity_scores.topk(k=max(self.training_args['cutoff']), dim=-1).indices.tolist()

            for top_k_indices in batch_top_k_indices:
                total_top_k_document.append(top_k_indices)

        print(">> Test results:")
        performance = {}
        for k in self.training_args['cutoff']:
            recall = recall_at_k(self.test_gt_ids, total_top_k_document, k)
            ndcg = ndcg_at_k(self.test_gt_ids, total_top_k_document, k)
            if zero_shot:
                performance[f'zero_shot/Recall@{k}'] = recall
                performance[f'zero_shot/NDCG@{k}'] = ndcg
            else:
                performance[f'test/Recall@{k}'] = recall
                performance[f'test/NDCG@{k}'] = ndcg

            print(f">>> Recall@{k}: {recall:.4f}, NDCG@{k}: {ndcg:.4f}")

        recommendation_results = [{'gt': self.test_gt_ids[idx], 'recommendation': total_top_k_document[idx]}
                                  for idx in range(len(self.test_gt_ids))]

        if self.wandb_logger is not None:
            self.wandb_logger.log(performance)

        return performance, recommendation_results
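
    # Training-loop sketch: each epoch re-encodes the item catalog and re-samples
    # negatives, validates with NDCG@10, and early-stops once validation fails to
    # improve for `patience` consecutive epochs; the best adapter is then reloaded
    # for the final test inference.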
    def train(self, ckpt_save_path, device):

        optimizer = self._set_optimizer(self.training_args['learning_rate'])
        scheduler = self._set_scheduler(optimizer=optimizer)
        epochs = self.training_args['epochs']

        self.model.to(device)
        print("\n> Zero-shot Performance.")
        self.item_embeddings = self._encode_items(self.item_dataloader, device=device).to(device)
        performance, recommendation_results = self.inference(device=device, zero_shot=True)

        if self.training_args['zero_shot']:
            return performance, recommendation_results

        best_val_ndcg_at_k = -np.inf
        patience = 0
        for epoch in range(epochs):
            self.train_one_epoch(epoch, optimizer, scheduler, device=device)
            avg_val_ndcg_at_k = self.valid_one_epoch(epoch, device=device)

            if best_val_ndcg_at_k <= avg_val_ndcg_at_k:
                best_val_ndcg_at_k = avg_val_ndcg_at_k
                patience = 0
                self.model.save_checkpoint(ckpt_save_path=ckpt_save_path)
                print(f">> Best model saved with NDCG@10: {best_val_ndcg_at_k:.4f}")
            else:
                patience += 1

            if patience == self.training_args['patience']:
                print(f">> Early stopped after {epoch + 1} epochs.")
                break

        self.model.load_best_checkpoint(ckpt_save_path=ckpt_save_path)
        performance, recommendation_results = self.inference(device=device, zero_shot=False)

        return performance, recommendation_results

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import json
import os
from datetime import datetime
from pathlib import Path

import torch
import wandb
from torch.utils.data import DataLoader
from wonderwords import RandomWord

from src import CoralTrainer, CoralDataset, CoralItemDataset, CoralEncoder
from utils import load_query, load_document, set_randomness


def main(training_args):
    training_args = vars(training_args)
    print("> Training arguments")
    for k, v in training_args.items():
        print(f'{k}: {v}')

    data_path = Path('data') / training_args['data_name']
    train_path = data_path / 'input_processed_train.jsonl'
    valid_path = data_path / 'input_processed_valid.jsonl'
    test_path = data_path / 'input_processed_test.jsonl'
    document_path = data_path / 'processed_document.json'

    train_gt_ids, train_conv_gt_ids, train_u = load_query(file_path=train_path,
                                                          used_info=training_args['query_used_info'])
    valid_gt_ids, _, valid_u = load_query(file_path=valid_path,
                                          used_info=training_args['query_used_info'])
    test_gt_ids, _, test_u = load_query(file_path=test_path,
                                        used_info=training_args['query_used_info'])
    document_ids, document = load_document(file_path=document_path, data_name=training_args['data_name'],
                                           used_info=training_args['doc_used_info'])
    # Convert item ids (e.g., 'Hoffa (1992)') to integer indices (e.g., 1628).
    idx2item = document_ids
    item2idx = {v: k for k, v in enumerate(idx2item)}
    for idx, cgt in enumerate(train_conv_gt_ids):
        train_conv_gt_ids[idx] = [item2idx[i] for i in cgt]
    train_gt_ids = [item2idx[i] for i in train_gt_ids]
    valid_gt_ids = [item2idx[i] for i in valid_gt_ids]
    test_gt_ids = [[item2idx[i] for i in test] for test in test_gt_ids]
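
    # Query/document layout sketch: 'c' holds the dialog history, 'l' / 'd' hold the
    # verbalized like / dislike preferences, and documents combine item metadata ('m')
    # with preference text ('pref'); see utils.format_query / utils.format_document.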
    print("\n\n> Sample data:")
    if 'c' in training_args['query_used_info']:
        print(train_u['c'][0])
    if 'l' in training_args['query_used_info']:
        print(train_u['l'][0])
    if 'd' in training_args['query_used_info']:
        print(train_u['d'][0])

    train_dataset = CoralDataset(user_text=train_u, gt_ids=train_gt_ids,
                                 base_tokenizer_name=training_args['base_model_name'])
    valid_dataset = CoralDataset(user_text=valid_u, gt_ids=valid_gt_ids,
                                 base_tokenizer_name=training_args['base_model_name'])
    test_dataset = CoralDataset(user_text=test_u, gt_ids=test_gt_ids,
                                base_tokenizer_name=training_args['base_model_name'])
    item_dataset = CoralItemDataset(item_text=document, base_tokenizer_name=training_args['base_model_name'])

    train_dataloader = DataLoader(train_dataset, batch_size=training_args['batch_size'], shuffle=True,
                                  collate_fn=train_dataset.collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=2 * training_args['batch_size'], shuffle=False,
                                  collate_fn=valid_dataset.collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=2 * training_args['batch_size'], shuffle=False,
                                 collate_fn=test_dataset.collate_fn)
    train_infer_dataloader = DataLoader(train_dataset, batch_size=4 * training_args['batch_size'], shuffle=False,
                                        collate_fn=train_dataset.collate_fn)
    item_dataloader = DataLoader(item_dataset, batch_size=4 * training_args['batch_size'], shuffle=False,
                                 collate_fn=item_dataset.collate_fn)

    # generate a single-token run name such as "<word>_<timestamp>"
    random_word_generator = RandomWord()
    while True:
        random_word = random_word_generator.random_words(include_parts_of_speech=["noun", "verb"])[0]
        if " " in random_word or "-" in random_word:
            continue
        else:
            break
    random_word_and_date = random_word + "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if training_args['wandb_project'] is not None:
        # read everything from the parsed-args dict rather than the global namespace
        wandb_tag = [
            training_args['data_name'],
            f"seed_{training_args['seed']}",
            f"query_used_info_{training_args['query_used_info']}",
            f"doc_used_info_{training_args['doc_used_info']}",
        ]
        wandb_logger = wandb.init(
            project=training_args['wandb_project'],
            entity=training_args['wandb_entity'],
            name=random_word_and_date,
            group=training_args['wandb_group'] if training_args['wandb_group'] else training_args['data_name'],
            config=training_args,
            tags=wandb_tag
        )
    else:
        wandb_logger = None

    model = CoralEncoder(base_model_name=training_args['base_model_name'],
                         alpha=training_args['alpha'],
                         beta=training_args['beta'])

    trainer = CoralTrainer(
        model=model,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        test_dataloader=test_dataloader,
        train_infer_dataloader=train_infer_dataloader,
        item_dataloader=item_dataloader,
        item_num=len(document_ids),
        train_conv_gt_ids=train_conv_gt_ids,
        train_gt_ids=train_gt_ids,
        valid_gt_ids=valid_gt_ids,
        test_gt_ids=test_gt_ids,
        optimizer='Adam',
        scheduler='get_linear_schedule_with_warmup',
        accumulation_steps=training_args['accumulation_steps'],
        training_args=training_args,
        wandb_logger=wandb_logger
    )

    ckpt_save_path = os.path.join(training_args['ckpt_save_path'], training_args['data_name'],
                                  random_word_and_date)
    try:
        os.makedirs(ckpt_save_path, exist_ok=True)
    except OSError:
        print("Error: Failed to create the directory.")
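
    # Run sketch: training uses CUDA when available; trainer.train restores the best
    # checkpoint (by validation NDCG@10) before the final test inference, and the
    # results below are written under outputs/results/.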
    device = "cuda" if torch.cuda.is_available() else "cpu"
    performance, recommendation_results = trainer.train(ckpt_save_path, device=device)

    output_results_path = Path('outputs') / 'results'
    os.makedirs(output_results_path, exist_ok=True)

    with open(os.path.join(output_results_path, f'{random_word_and_date}.json'), 'w') as f:
        json.dump({
            'version': random_word_and_date,
            'arguments': training_args,
            'performance': performance,
            'recommendation_results': recommendation_results
        }, f, indent=1, ensure_ascii=False)


def parse_args():
    parser = argparse.ArgumentParser()
    # mode
    parser.add_argument('--zero_shot', action='store_true')
    # data
    parser.add_argument('--data_name', type=str, required=True, choices=['inspired', 'redial', 'pearl'])
    # hyperparameter
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--beta', type=float, default=0.2)
    parser.add_argument('--temp', type=float, default=0.05)
    parser.add_argument("--query_used_info", type=str, nargs='+', choices=['c', 'l', 'd'], default=['c', 'l', 'd'])
    parser.add_argument("--doc_used_info", type=str, nargs='+', choices=['m', 'pref'], default=['m', 'pref'])
    # train
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--negative_sample", type=int, default=16)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--accumulation_steps', type=int, default=8)
    parser.add_argument('--base_model_name', type=str, default='nvidia/NV-Embed-v1')
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--patience', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--ckpt_save_path', type=str, default='checkpoints')
    parser.add_argument('--seed', type=int, default=2024)
    parser.add_argument('--cutoff', type=int, nargs='+', default=[5, 10, 50])
    # wandb
    parser.add_argument("--wandb_project", type=str)
    parser.add_argument("--wandb_entity", type=str)
    parser.add_argument('--wandb_group', type=str)

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    set_randomness(args.seed)

    main(args)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import json
import random

import numpy as np
import torch


def set_randomness(seed, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def load_jsonl(file_path):
    with open(file_path, 'r', encoding='utf8') as f:
        return [json.loads(line) for line in f]


def load_json(file_path):
    with open(file_path, 'r', encoding='utf8') as f:
        return json.load(f)
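
# Example document string (hypothetical metadata, for illustration only):
#   "Title: Heat (1995). Genre: Crime. Cast: Al Pacino, Robert De Niro. Director: Michael Mann"
# followed by the concatenated like/dislike preference text when 'pref' is used.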


def format_document(doc_info, data_name, used_info):
    doc = ''
    if 'm' in used_info:
        # the three datasets share the same metadata format
        if any(name in data_name for name in ('pearl', 'inspired', 'redial')):
            doc += (f"Title: {doc_info['title']}. Genre: {doc_info['genre']}. "
                    f"Cast: {', '.join(doc_info['cast'])}. "
                    f"Director: {', '.join(doc_info['director'])}\n")
        else:
            raise ValueError(f"Invalid data_name in formatting document: {data_name}")
    if 'pref' in used_info:
        doc += ', '.join(doc_info['like'])
        doc += ', '.join(doc_info['dislike'])

    return doc


def load_document(file_path, data_name, used_info=None):
    document = load_json(file_path)
    item_ids = list(document.keys())
    document_dict = {'item': [format_document(document[_id], data_name, used_info) for _id in item_ids]}
    return item_ids, document_dict


def format_query(query, used_info):
    query_str = ''
    if 'c' in used_info:
        query_str += f"{query['input_dialog_history']}"
    elif 'l' in used_info:
        query_str += f"Like Preference: {','.join(query['preference']['like'])}"
    elif 'd' in used_info:
        query_str += f"Dislike Preferences: {','.join(query['preference']['dislike'])}"
    else:
        raise ValueError(f"Invalid used_info in formatting query: {used_info}")

    return query_str


def load_query(file_path, used_info):
    queries = load_jsonl(file_path)

    if 'test' in str(file_path):
        gt_ids = []
        for input_ in queries:
            gt_ids.append([gt['id'] for gt in input_['gt']])  # each test query may have multiple ground-truth items
        conv_gt_ids = None

    elif 'train' in str(file_path):
        gt_ids = [input_['train_gt']['id'] for input_ in queries]  # one training target per query
        conv_gt_ids = []
        for input_ in queries:
            conv_gt_ids.append([gt['id'] for gt in input_['gt']])
    else:
        gt_ids = [input_['gt']['id'] for input_ in queries]  # one validation target per query
        conv_gt_ids = None

    query_dict = {}
    for key in used_info:
        if key not in ['c', 'l', 'd']:
            raise ValueError(f"Invalid used_info in preparing query: {used_info}")
        query_dict[key] = [format_query(input_, key) for input_ in queries]

    return gt_ids, conv_gt_ids, query_dict

--------------------------------------------------------------------------------