├── .github └── dependabot.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── args.py ├── compute_metrics.py ├── data ├── __init__.py └── dataloader_musee.py ├── experiment_musee.py ├── img ├── metric.png └── model.png ├── metric.py ├── model ├── __init__.py ├── t5_with_t5decoder.py ├── t5_with_t5decoder_emb.py └── utils │ ├── __init__.py │ ├── dist_util.py │ ├── fp16_util.py │ ├── logger.py │ ├── losses.py │ └── nn.py ├── requirements.txt ├── trainer ├── __init__.py └── trainer_musee.py └── utils.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" 9 | directory: "/" 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Mono auto generated files 17 | mono_crash.* 18 | 19 | # Build results 20 | [Dd]ebug/ 21 | [Dd]ebugPublic/ 22 | [Rr]elease/ 23 | [Rr]eleases/ 24 | x64/ 25 | x86/ 26 | [Ww][Ii][Nn]32/ 27 | [Aa][Rr][Mm]/ 28 | [Aa][Rr][Mm]64/ 29 | bld/ 30 | [Bb]in/ 31 | [Oo]bj/ 32 | [Ll]og/ 33 | [Ll]ogs/ 34 | 35 | # Visual Studio 2015/2017 cache/options directory 36 | .vs/ 37 | # Uncomment if you have tasks that create the project's static files in wwwroot 38 | #wwwroot/ 39 | 40 | # Visual Studio 2017 auto generated files 41 | Generated\ Files/ 42 | 43 | # MSTest test Results 44 | [Tt]est[Rr]esult*/ 45 | [Bb]uild[Ll]og.* 46 | 47 | # NUnit 48 | *.VisualState.xml 49 | TestResult.xml 50 | nunit-*.xml 51 | 52 | # Build Results of an ATL Project 53 | [Dd]ebugPS/ 54 | [Rr]eleasePS/ 55 | dlldata.c 56 | 57 | # Benchmark Results 58 | BenchmarkDotNet.Artifacts/ 59 | 60 | # .NET Core 61 | project.lock.json 62 | project.fragment.lock.json 63 | artifacts/ 64 | 65 | # ASP.NET Scaffolding 66 | ScaffoldingReadMe.txt 67 | 68 | # StyleCop 69 | StyleCopReport.xml 70 | 71 | # Files built by Visual Studio 72 | *_i.c 73 | *_p.c 74 | *_h.h 75 | *.ilk 76 | *.meta 77 | *.obj 78 | *.iobj 79 | *.pch 80 | *.pdb 81 | *.ipdb 82 | *.pgc 83 | *.pgd 84 | *.rsp 85 | *.sbr 86 | *.tlb 87 | *.tli 88 | *.tlh 89 | *.tmp 90 | *.tmp_proj 91 | *_wpftmp.csproj 92 | *.log 93 | *.tlog 94 | *.vspscc 95 | *.vssscc 96 | .builds 97 | *.pidb 98 | *.svclog 99 | *.scc 100 | 101 | # Chutzpah Test files 102 | _Chutzpah* 103 | 104 | # Visual C++ cache files 105 | ipch/ 106 | *.aps 107 | *.ncb 108 | *.opendb 109 | *.opensdf 110 | *.sdf 111 | *.cachefile 112 | *.VC.db 113 | *.VC.VC.opendb 114 | 115 | # Visual Studio profiler 116 | *.psess 117 | *.vsp 118 | *.vspx 119 | *.sap 120 | 121 | # Visual Studio Trace Files 122 | *.e2e 123 | 124 | # TFS 2012 Local 
Workspace 125 | $tf/ 126 | 127 | # Guidance Automation Toolkit 128 | *.gpState 129 | 130 | # ReSharper is a .NET coding add-in 131 | _ReSharper*/ 132 | *.[Rr]e[Ss]harper 133 | *.DotSettings.user 134 | 135 | # TeamCity is a build add-in 136 | _TeamCity* 137 | 138 | # DotCover is a Code Coverage Tool 139 | *.dotCover 140 | 141 | # AxoCover is a Code Coverage Tool 142 | .axoCover/* 143 | !.axoCover/settings.json 144 | 145 | # Coverlet is a free, cross platform Code Coverage Tool 146 | coverage*.json 147 | coverage*.xml 148 | coverage*.info 149 | 150 | # Visual Studio code coverage results 151 | *.coverage 152 | *.coveragexml 153 | 154 | # NCrunch 155 | _NCrunch_* 156 | .*crunch*.local.xml 157 | nCrunchTemp_* 158 | 159 | # MightyMoose 160 | *.mm.* 161 | AutoTest.Net/ 162 | 163 | # Web workbench (sass) 164 | .sass-cache/ 165 | 166 | # Installshield output folder 167 | [Ee]xpress/ 168 | 169 | # DocProject is a documentation generator add-in 170 | DocProject/buildhelp/ 171 | DocProject/Help/*.HxT 172 | DocProject/Help/*.HxC 173 | DocProject/Help/*.hhc 174 | DocProject/Help/*.hhk 175 | DocProject/Help/*.hhp 176 | DocProject/Help/Html2 177 | DocProject/Help/html 178 | 179 | # Click-Once directory 180 | publish/ 181 | 182 | # Publish Web Output 183 | *.[Pp]ublish.xml 184 | *.azurePubxml 185 | # Note: Comment the next line if you want to checkin your web deploy settings, 186 | # but database connection strings (with potential passwords) will be unencrypted 187 | *.pubxml 188 | *.publishproj 189 | 190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 191 | # checkin your Azure Web App publish settings, but sensitive information contained 192 | # in these scripts will be unencrypted 193 | PublishScripts/ 194 | 195 | # NuGet Packages 196 | *.nupkg 197 | # NuGet Symbol Packages 198 | *.snupkg 199 | # The packages folder can be ignored because of Package Restore 200 | **/[Pp]ackages/* 201 | # except build/, which is used as an MSBuild target. 202 | !**/[Pp]ackages/build/ 203 | # Uncomment if necessary however generally it will be regenerated when needed 204 | #!**/[Pp]ackages/repositories.config 205 | # NuGet v3's project.json files produces more ignorable files 206 | *.nuget.props 207 | *.nuget.targets 208 | 209 | # Microsoft Azure Build Output 210 | csx/ 211 | *.build.csdef 212 | 213 | # Microsoft Azure Emulator 214 | ecf/ 215 | rcf/ 216 | 217 | # Windows Store app package directories and files 218 | AppPackages/ 219 | BundleArtifacts/ 220 | Package.StoreAssociation.xml 221 | _pkginfo.txt 222 | *.appx 223 | *.appxbundle 224 | *.appxupload 225 | 226 | # Visual Studio cache files 227 | # files ending in .cache can be ignored 228 | *.[Cc]ache 229 | # but keep track of directories ending in .cache 230 | !?*.[Cc]ache/ 231 | 232 | # Others 233 | ClientBin/ 234 | ~$* 235 | *~ 236 | *.dbmdl 237 | *.dbproj.schemaview 238 | *.jfm 239 | *.pfx 240 | *.publishsettings 241 | orleans.codegen.cs 242 | 243 | # Including strong name files can present a security risk 244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 245 | #*.snk 246 | 247 | # Since there are multiple workflows, uncomment next line to ignore bower_components 248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 249 | #bower_components/ 250 | 251 | # RIA/Silverlight projects 252 | Generated_Code/ 253 | 254 | # Backup & report files from converting an old project file 255 | # to a newer Visual Studio version. 
Backup files are not needed, 256 | # because we have git ;-) 257 | _UpgradeReport_Files/ 258 | Backup*/ 259 | UpgradeLog*.XML 260 | UpgradeLog*.htm 261 | ServiceFabricBackup/ 262 | *.rptproj.bak 263 | 264 | # SQL Server files 265 | *.mdf 266 | *.ldf 267 | *.ndf 268 | 269 | # Business Intelligence projects 270 | *.rdl.data 271 | *.bim.layout 272 | *.bim_*.settings 273 | *.rptproj.rsuser 274 | *- [Bb]ackup.rdl 275 | *- [Bb]ackup ([0-9]).rdl 276 | *- [Bb]ackup ([0-9][0-9]).rdl 277 | 278 | # Microsoft Fakes 279 | FakesAssemblies/ 280 | 281 | # GhostDoc plugin setting file 282 | *.GhostDoc.xml 283 | 284 | # Node.js Tools for Visual Studio 285 | .ntvs_analysis.dat 286 | node_modules/ 287 | 288 | # Visual Studio 6 build log 289 | *.plg 290 | 291 | # Visual Studio 6 workspace options file 292 | *.opt 293 | 294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 295 | *.vbw 296 | 297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 298 | *.vbp 299 | 300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 301 | *.dsw 302 | *.dsp 303 | 304 | # Visual Studio 6 technical files 305 | *.ncb 306 | *.aps 307 | 308 | # Visual Studio LightSwitch build output 309 | **/*.HTMLClient/GeneratedArtifacts 310 | **/*.DesktopClient/GeneratedArtifacts 311 | **/*.DesktopClient/ModelManifest.xml 312 | **/*.Server/GeneratedArtifacts 313 | **/*.Server/ModelManifest.xml 314 | _Pvt_Extensions 315 | 316 | # Paket dependency manager 317 | .paket/paket.exe 318 | paket-files/ 319 | 320 | # FAKE - F# Make 321 | .fake/ 322 | 323 | # CodeRush personal settings 324 | .cr/personal 325 | 326 | # Python Tools for Visual Studio (PTVS) 327 | __pycache__/ 328 | *.pyc 329 | 330 | # Cake - Uncomment if you are using it 331 | # tools/** 332 | # !tools/packages.config 333 | 334 | # Tabs Studio 335 | *.tss 336 | 337 | # Telerik's JustMock configuration file 338 | *.jmconfig 339 | 340 | # BizTalk build output 341 | *.btp.cs 342 | *.btm.cs 343 | *.odx.cs 344 | *.xsd.cs 345 | 346 | # OpenCover UI analysis results 347 | OpenCover/ 348 | 349 | # Azure Stream Analytics local run output 350 | ASALocalRun/ 351 | 352 | # MSBuild Binary and Structured Log 353 | *.binlog 354 | 355 | # NVidia Nsight GPU debugger configuration file 356 | *.nvuser 357 | 358 | # MFractors (Xamarin productivity tool) working folder 359 | .mfractor/ 360 | 361 | # Local History for Visual Studio 362 | .localhistory/ 363 | 364 | # Visual Studio History (VSHistory) files 365 | .vshistory/ 366 | 367 | # BeatPulse healthcheck temp database 368 | healthchecksdb 369 | 370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 371 | MigrationBackup/ 372 | 373 | # Ionide (cross platform F# VS Code tools) working folder 374 | .ionide/ 375 | 376 | # Fody - auto-generated XML schema 377 | FodyWeavers.xsd 378 | 379 | # VS Code files for those working on multiple tools 380 | .vscode/* 381 | !.vscode/settings.json 382 | !.vscode/tasks.json 383 | !.vscode/launch.json 384 | !.vscode/extensions.json 385 | *.code-workspace 386 | 387 | # Local History for Visual Studio Code 388 | .history/ 389 | 390 | # Windows Installer files from build outputs 391 | *.cab 392 | *.msi 393 | *.msix 394 | *.msm 395 | *.msp 396 | 397 | # JetBrains Rider 398 | *.sln.iml 399 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # 
Microsoft Open Source Code of Conduct
2 | 
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 | 
5 | Resources:
6 | 
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) Microsoft Corporation.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Extract Structured Entities Using Language Models
2 | 
3 | [License](LICENSE)
4 | 
5 | 
6 | 
7 | 
8 | 
9 | [![CodeQL](https://github.com/microsoft/Structured-Entity-Extraction/actions/workflows/github-code-scanning/codeql/badge.svg)](https://github.com/microsoft/Structured-Entity-Extraction/actions/workflows/github-code-scanning/codeql)
10 | 
11 | [![Dependabot Updates](https://github.com/microsoft/Structured-Entity-Extraction/actions/workflows/dependabot/dependabot-updates/badge.svg)](https://github.com/microsoft/Structured-Entity-Extraction/actions/workflows/dependabot/dependabot-updates)
12 | 
13 | 
14 | **[🔥 Oral, top 7% of all accepted papers 🔥]**
15 | 
16 | ⚙️ This is the implementation of ["**Learning to Extract Structured Entities Using Language Models**"](https://arxiv.org/pdf/2402.04437), a collaboration between **MSR** and **Mila**, accepted to the **EMNLP 2024 Main conference**.
17 | 
18 | ## Abstract
19 | Recent advances in machine learning have significantly impacted the field of information extraction, with Language Models (LMs) playing a pivotal role in extracting structured information from unstructured text. Prior works typically represent information extraction as triplet-centric and use classical metrics such as precision and recall for evaluation. We reformulate the task to be entity-centric, enabling the use of diverse metrics that can provide more insights from various perspectives. We contribute to the field by introducing Structured Entity Extraction and proposing the Approximate Entity Set OverlaP (AESOP) metric, designed to appropriately assess model performance. Later, we introduce a new Multi-stage Structured Entity Extraction (MuSEE) model that harnesses the power of LMs for enhanced effectiveness and efficiency by decomposing the extraction task into multiple stages. Quantitative and human side-by-side evaluations confirm that our model outperforms baselines, offering promising directions for future advancements in structured entity extraction.
20 | 
21 | ![MuSEE model](img/model.png)
22 | 
23 | ![AESOP metric](img/metric.png)
24 | ## Install Dependencies
25 | ```bash
26 | conda create -n MuSEE python=3.8 --file requirements.txt
27 | conda activate MuSEE
28 | ```
29 | 
30 | ## Directory structure
31 | ```
32 | data/                     # Dataset generation code will be released soon, pending an internal review process.
33 | |-- GPT4-based/           # GPT4-based dataset
34 | |-- Wikidata-based/       # Wikidata-based dataset
35 | |-- nyt/                  # New York Times Relation Extraction dataset
36 | |-- conll04/              # CoNLL04 dataset
37 | |-- REBEL/                # REBEL dataset
38 | |-- TREX/                 # T-REx dataset
39 | |-- dataloader_musee.py   # Dataloader for the MuSEE model
40 | model/
41 | |-- t5_with_t5decoder.py  # Base model architecture for MuSEE
42 | trainer/
43 | |-- trainer_musee.py      # Trainer for the MuSEE model
44 | args.py                   # Arguments for the MuSEE model and running experiments
45 | experiment_musee.py       # Main file to run experiments
46 | metric.py                 # Calculate different variants of the proposed AESOP metric
47 | compute_metrics.py        # Calculate metrics for the entire dataset
48 | requirements.txt          # Required packages
49 | utils.py                  # Utility functions
50 | ```
51 | 
52 | ## Run the code
53 | ```bash
54 | python experiment_musee.py \
55 |     --model_choice=musee \
56 |     --dataset=gpt4 \
57 |     --pretrained_model_name=t5-large \
58 |     --batch_size=1 \
59 |     --epochs=100 \
60 |     --log_wandb=True \
61 |     --use_lora=True \
62 |     --lr=1e-4 \
63 |     --weight_decay=1e-2 \
64 |     --mode=train \
65 |     --loss_mode=mean \
66 |     --use_better_init=True
67 | ```
68 | 
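After training, generated outputs are assembled into a predictions JSON (see `generate_final_json` in `experiment_musee.py`) and scored with `evaluate` from `compute_metrics.py`, which averages all nine AESOP variants (three matching measures × three normalizations) over the documents and writes a `*_metrics.json` file next to the predictions. A minimal sketch — the two paths below are placeholders, not files shipped with the repo:

```python
from compute_metrics import evaluate

# Both JSON files map document IDs to {"description": ..., "entities": {...}}.
results = evaluate(
    ground_truth_path="data/D2_final/D2_test_final.json",  # placeholder path
    prediction_path="generation_output/predictions.json",  # placeholder path
)
print(results["multi_prop_max"])  # AESOP with multi-property matching, max normalization
```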
69 | ## Citation and Contact
70 | If you find this paper useful, please cite our work:
71 | ```
72 | @inproceedings{wu2024structured,
73 |   title={Learning to Extract Structured Entities Using Language Models},
74 |   author={Wu, Haolun and Yuan, Ye and Mikaelyan, Liana and Meulemans, Alexander and Liu, Xue and Hensman, James and Mitra, Bhaskar},
75 |   booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing",
76 |   month = nov,
77 |   year = "2024",
78 |   address = "Miami, USA",
79 |   publisher = "Association for Computational Linguistics",
80 | }
81 | ```
82 | 
83 | 💬 If you have any questions, feel free to contact us through email (haolun.wu@mail.mcgill.ca, ye.yuan3@mail.mcgill.ca) or GitHub issues. Enjoy!
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Security
4 | 
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 | 
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 | 
9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message.
Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
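One usage note before the Python sources: every boolean flag in `args.py` below is declared with `type=str2bool, nargs="?"`, so values are passed explicitly in `--flag=value` form (as the README's run command does), and `yes/no`, `true/false`, `t/f`, `y/n`, and `1/0` are all accepted in any letter case. A minimal, self-contained sketch of the pattern, using a hypothetical `--demo_flag`:

```python
from argparse import ArgumentParser, ArgumentTypeError


def str2bool(v):  # same logic as the helper in args.py below
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise ArgumentTypeError("Boolean value expected.")


parser = ArgumentParser()
parser.add_argument("--demo_flag", type=str2bool, nargs="?", default=False)  # hypothetical flag
assert parser.parse_args(["--demo_flag=YES"]).demo_flag is True
assert parser.parse_args([]).demo_flag is False
```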
26 | 
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser, ArgumentTypeError
3 | from datetime import datetime
4 | 
5 | 
6 | def str2bool(v):
7 |     if isinstance(v, bool):
8 |         return v
9 |     if v.lower() in ("yes", "true", "t", "y", "1"):
10 |         return True
11 |     elif v.lower() in ("no", "false", "f", "n", "0"):
12 |         return False
13 |     else:
14 |         raise ArgumentTypeError("Boolean value expected.")
15 | 
16 | 
17 | def parse_args():  # Parse command line arguments
18 |     parser = ArgumentParser(description="mlm_seq")
19 |     parser.add_argument(
20 |         "--use_data", default=5, type=int, help="The number of data samples used"
21 |     )
22 |     parser.add_argument(
23 |         "--batch_size", default=1, type=int, help="Batch size for training"
24 |     )
25 |     parser.add_argument(
26 |         "--epochs", default=500, type=int, help="Number of epochs to train for"
27 |     )
28 |     parser.add_argument(
29 |         "--lr", default=3e-4, type=float, help="Learning rate for the optimizer"
30 |     )
31 |     parser.add_argument(
32 |         "--special_token_lr",
33 |         default=1e-2,
34 |         type=float,
35 |         help="Learning rate for the special-token embeddings",
36 |     )
37 |     parser.add_argument(
38 |         "--use_special_token",
39 |         type=str2bool,
40 |         nargs="?",
41 |         default=True,
42 |         help="true denotes using special tokens",
43 |     )
44 |     parser.add_argument(
45 |         "--use_diff_lr",
46 |         type=str,
47 |         default="all_embed",
48 |         choices=["special", "all_embed"],
49 |         help="all_embed denotes using a larger learning rate for all embedding layers",
50 |     )
51 |     parser.add_argument(
52 |         "--use_better_init",
53 |         type=str2bool,
54 |         nargs="?",
55 |         default=True,
56 |         help="true denotes using better initialization",
57 |     )
58 |     parser.add_argument(
59 |         "--noise_std_dev",
60 |         default=1e-3,
61 |         type=float,
62 |         help="the standard deviation for the random Gaussian noise",
63 |     )
64 |     parser.add_argument(
65 |         "--dataset",
66 |         default="toy",
67 |         type=str,
68 |         choices=["internal", "gpt4", "wikidata", "wikidata-hallu", "toy"],
69 |         help="Dataset name",
70 |     )
71 |     parser.add_argument(
72 |         "--weight_decay",
73 |         default=1e-2,
74 |         type=float,
75 |         help="Weight decay for the optimizer",
76 |     )
77 |     parser.add_argument("--use_lora", type=str2bool, nargs="?", default=True)
78 |     parser.add_argument("--alpha", default=0.5, type=float)
79 |     parser.add_argument(
80 |         "--encoder_lr", default=1e-3, type=float, help="Learning rate for the encoder"
81 |     )
82 |     parser.add_argument(
83 |         "--decoder_lr", default=1e-3, type=float, help="Learning rate for the decoder"
84 |     )
85 |     parser.add_argument(
86 |         "--num_fine_tune",
87 |         default=100,
88 |         type=int,
89 |         help="Number of epochs to fine-tune the pre-trained BERT",
90 |     )
91 |     parser.add_argument(
92 |         "--ignore_zero",
93 |         type=str2bool,
94 |         nargs="?",
95 |         default=True,
96 |         help="True denotes ignoring the padding zeros when calculating the CE loss",
97 |     )
98 |     parser.add_argument(
99 |         "--norm_loss",
100 |         type=str2bool,
101 |         nargs="?",
102 |         default=True,
103 |         help="Normalize the CE loss by the number of tokens in each property",
104 |     )
105 |     parser.add_argument(
106 |         "--neg_sample",
107 |         type=str2bool,
108 |         nargs="?",
109 |         default=True,
110 |         help="Use negative sampling when predicting property key existence",
111 |     )
112 |     parser.add_argument(
113 |         "--use_pretrained",
114 |         type=str2bool,
115 |         nargs="?",
116 |         default=True,
117 |         help="True denotes using a pretrained model from Hugging Face",
118 |     )
119 |     parser.add_argument(
120 |         "--only_eval",
121 |         type=str2bool,
122 |         nargs="?",
123 |         default=False,
124 |         help="Only conduct evaluation",
125 |     )
126 |     parser.add_argument(
127 |         "--log_wandb",
128 |         type=str2bool,
129 |         nargs="?",
130 |         default=False,
131 |         help="True denotes using wandb to log the training process",
132 |     )
133 |     parser.add_argument(
134 |         "--pretrained_model_name",
135 |         default="t5-small",
136 |         type=str,
137 |         # choices=["t5-small", "t5-base", "t5-large"],
138 |         help="Which base model to use",
139 |     )
140 |     parser.add_argument(
141 |         "--decoder_choice",
142 |         default="T5",
143 |         type=str,
144 |         choices=["BERT", "GPT2", "T5"],
145 |         help="Which decoder structure we will use as the decoder",
146 |     )
147 |     parser.add_argument(
148 |         "--num_optimizer",
149 |         default="1",
150 |         type=str,
151 |         choices=["1", "2", "3"],
152 |         help="Whether to use separate optimizers for the encoder and decoder",
153 |     )
154 |     parser.add_argument(
155 |         "--seed",
156 |         default=42,
157 |         type=int,
158 |         help="Random seed for reproducibility",
159 |     )
160 |     parser.add_argument("--mode", choices=["train", "test"], default="train", type=str)
161 |     parser.add_argument("--loss_mode", choices=["sum", "mean"], default="sum", type=str)
162 |     parser.add_argument(
163 |         "--check_epoch", default=6, type=int, help="Run evaluation every check_epoch epochs"
164 |     )
165 |     parser.add_argument("--top_entity", default=20, type=int, help="Keep the top_entity most frequent entity types")
166 |     parser.add_argument("--top_property", default=50, type=int, help="Keep the top_property most frequent property keys")
167 |     parser.add_argument(
168 |         "--perturbation_test",
169 |         default=True,
170 |         type=str2bool,
171 |         nargs="?",
172 |         help="True denotes using a perturbed subset to test the model",
173 |     )
174 |     parser.add_argument(
175 |         "--perturbation_exp",
176 |         default=True,
177 |         type=str2bool,
178 |         nargs="?",
179 |         help="True denotes using the perturbation testing set, "
180 |         "and False denotes using the regular testing set",
181 |     )
182 |     parser.add_argument(
183 |         "--training_mode",
184 |         default=True,
185 |         type=str2bool,
186 |         nargs="?",
187 |         help="False denotes loading a saved model",
188 |     )
189 |     parser.add_argument(
190 |         "--decode_type", default="at", type=str, help="AT vs. NAT for the decoder"
191 |     )
192 |     parser.add_argument(
193 |         "--model_choice",
194 |         default="musee",
195 |         choices=[
196 |             "Single-mask-Multi-Entity-Step1",
197 |             "Single-mask-Multi-Entity-Step2",
198 |             "Single-mask-Multi-Entity-M3",
199 |             "generative-llm",
200 |             "M1",
201 |             "E1",
202 |             "musee",
203 |         ],
204 |         type=str,
205 |     )
206 |     parser.add_argument(
207 |         "--generative_model",
208 |         type=str,
209 |         default="t5-small",
210 |         choices=["gpt2", "gpt2-large", "t5-small", "t5-base", "t5-large", "llama_3B"],
211 |         help="Which LLM to use for the generative LLM modeling",
212 |     )
213 |     parser.add_argument(
214 |         "--start_sentence",
215 |         default="\n\nCreate a JSON file containing all named entities in the previous text:\n",
216 |         type=str,
217 |     )
218 |     parser.add_argument("--total_batch_size", default=32, type=int)
219 |     parser.add_argument("--gradient_accumulation_steps", default=32, type=int)
220 |     parser.add_argument("--adam_beta1", default=0.9, type=float)
221 |     parser.add_argument("--adam_beta2", default=0.999, type=float)
222 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float)
223 |     parser.add_argument("--max_grad_norm", default=0.01, type=float)
224 |     parser.add_argument("--lr_scheduler_type", default="linear", type=str)
225 |     parser.add_argument("--num_warmup_steps", default=90, type=int)
226 |
parser.add_argument("--no_cuda", type=str2bool, nargs="?", default=False) 227 | parser.add_argument("--generate_num_return_sequences", default=1, type=int) 228 | parser.add_argument("--generate_temperature", default=0.7, type=float) 229 | parser.add_argument("--generate_top_k", default=50, type=int) 230 | parser.add_argument("--generate_top_p", default=0.95, type=float) 231 | parser.add_argument("--generate_do_sample", type=str2bool, nargs="?", default=False) 232 | parser.add_argument("--generate_num_beams", default=1, type=int) 233 | parser.add_argument("--evaluation_strategy", default="epoch", type=str) 234 | parser.add_argument("--logging_steps", default=100, type=int) 235 | parser.add_argument("--save_final_model", default=True, type=str2bool, nargs="?") 236 | parser.add_argument("--save_strategy", default="epoch", type=str) 237 | parser.add_argument("--lora_alpha", default=32, type=int) 238 | parser.add_argument("--lora_dropout", default=0.1, type=float) 239 | parser.add_argument("--lora_r", default=16, type=int) 240 | parser.add_argument("--max_length", type=int, default=2048) 241 | parser.add_argument("--output_dir", type=str, default="logs/logs") 242 | parser.add_argument( 243 | "--load_best_model_at_end", type=str2bool, nargs="?", default=False 244 | ) 245 | parser.add_argument( 246 | "--torch_dtype", default="float16", type=str, choices=["float16", "float32"] 247 | ) 248 | parser.add_argument("--lora_target_modules", default=None) 249 | parser.add_argument("--lora_modules_to_save", default=None) 250 | parser.add_argument("--tight_padding", type=str2bool, nargs="?", default=True) 251 | parser.add_argument("--save_total_limit", type=int, default=2) 252 | parser.add_argument("--saved_model_path", type=str) 253 | parser.add_argument( 254 | "--eval_batch_size", 255 | default=8, 256 | type=int, 257 | help="Batch size for LLM evaluation and output generation", 258 | ) 259 | parser.add_argument("--generate_output_path", type=str, default="generation_output") 260 | parser.add_argument("--st_checkpoint_dir", type=str, default="st_checkpoint") 261 | 262 | args = parser.parse_args() 263 | return postprocess_args(args) 264 | 265 | 266 | def postprocess_args(args): 267 | curr_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 268 | output_dir = os.path.join( 269 | args.output_dir 270 | + f"_model_{args.model_choice}" 271 | + f"_{args.dataset}" 272 | + f"_lr_{str(args.lr)}" 273 | + f"_wd_{str(args.weight_decay)}" 274 | + f"_alpha_{str(args.alpha)}" 275 | + f"_lora_{str(args.use_lora)}" 276 | + f"_loss_mode_{str(args.loss_mode)}" 277 | + f"_init_{str(args.use_better_init)}" 278 | + f"_{str(args.mode)}", 279 | f"run_{curr_date}", 280 | ) 281 | args.out_dir = output_dir 282 | if not os.path.exists(output_dir): 283 | os.makedirs(output_dir) 284 | 285 | return args 286 | -------------------------------------------------------------------------------- /compute_metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from metric import compute_bipartite_matching_metrics 6 | 7 | 8 | def evaluate(ground_truth_path, prediction_path): 9 | target_count = 0 10 | target_type_counts = defaultdict(int) 11 | generated_count = 0 12 | generated_type_counts = defaultdict(int) 13 | target_entities_without_type = 0 14 | generated_entities_without_type = 0 15 | 16 | metrics = [] 17 | counts = [] 18 | nested_counts = [] 19 | 20 | with open(ground_truth_path, "r") as f: 21 | ground_truth = 
json.load(f) 22 | 23 | with open(prediction_path, "r") as f: 24 | predictions = json.load(f) 25 | 26 | new_dict = {} 27 | for k, v in enumerate(ground_truth.values()): 28 | new_dict[str(k)] = v 29 | ground_truth = new_dict 30 | 31 | new_dict = {} 32 | for k, v in enumerate(predictions.values()): 33 | new_dict[str(k)] = v 34 | predictions = new_dict 35 | 36 | print("Length of ground truth:", len(ground_truth)) 37 | print("Length of predictions:", len(predictions)) 38 | 39 | # Ensure both files have matching document IDs 40 | assert set(ground_truth.keys()) == set( 41 | predictions.keys() 42 | ), "Mismatch in document IDs between ground truth and predictions." 43 | 44 | exact_name_max = [] 45 | exact_name_precision = [] 46 | exact_name_recall = [] 47 | approx_name_max = [] 48 | approx_name_precision = [] 49 | approx_name_recall = [] 50 | multi_prop_max = [] 51 | multi_prop_precision = [] 52 | multi_prop_recall = [] 53 | for doc_id, doc_data in ground_truth.items(): 54 | target_entities = doc_data["entities"] 55 | target_count += len(target_entities) 56 | 57 | for entity in target_entities.values(): 58 | try: 59 | target_type_counts[entity["type"]] += 1 60 | except KeyError: 61 | target_entities_without_type += 1 62 | 63 | generated_output = predictions[doc_id]["entities"] 64 | 65 | generated_entities = list(generated_output.values()) 66 | generated_count += len(generated_entities) 67 | 68 | for entity in generated_entities: 69 | if "type" in entity: 70 | generated_type_counts[entity["type"]] += 1 71 | else: 72 | generated_entities_without_type += 1 73 | 74 | for normalization in ["Max", "Precision", "Recall"]: 75 | for measures in ["ExactName", "ApproxName", "MultiProp"]: 76 | ( 77 | final_metrics, 78 | count, 79 | nested_count, 80 | ) = compute_bipartite_matching_metrics( 81 | target_entities, 82 | generated_entities, 83 | measures=measures, 84 | normalization=normalization, 85 | establish_threshold=0.6, 86 | ) 87 | if measures == "ExactName": 88 | if normalization == "Max": 89 | exact_name_max.append(final_metrics["normalized_similarity"]) 90 | elif normalization == "Precision": 91 | exact_name_precision.append( 92 | final_metrics["normalized_similarity"] 93 | ) 94 | elif normalization == "Recall": 95 | exact_name_recall.append(final_metrics["normalized_similarity"]) 96 | elif measures == "ApproxName": 97 | if normalization == "Max": 98 | approx_name_max.append(final_metrics["normalized_similarity"]) 99 | elif normalization == "Precision": 100 | approx_name_precision.append( 101 | final_metrics["normalized_similarity"] 102 | ) 103 | elif normalization == "Recall": 104 | approx_name_recall.append( 105 | final_metrics["normalized_similarity"] 106 | ) 107 | elif measures == "MultiProp": 108 | if normalization == "Max": 109 | multi_prop_max.append(final_metrics["normalized_similarity"]) 110 | metrics.append(final_metrics) 111 | counts.append(count) 112 | nested_counts.append(nested_count) 113 | elif normalization == "Precision": 114 | multi_prop_precision.append( 115 | final_metrics["normalized_similarity"] 116 | ) 117 | elif normalization == "Recall": 118 | multi_prop_recall.append(final_metrics["normalized_similarity"]) 119 | # Compute and log metrics for generated text 120 | keys = set(key for d in metrics for key in d.keys()) 121 | quantiles = [5, 10, 25, 50, 75, 90, 95] 122 | 123 | def compute_quantiles(data, quantiles): 124 | return {q: np.percentile(data, q) for q in quantiles} 125 | 126 | avg_metrics = { 127 | key: { 128 | "average": np.mean( 129 | [metric[key] for metric in metrics if 
key in metric.keys()] 130 | ), 131 | "quantiles": compute_quantiles( 132 | [metric[key] for metric in metrics if key in metric.keys()], 133 | quantiles, 134 | ), 135 | "raw_data": [metric[key] for metric in metrics if key in metric.keys()], 136 | } 137 | for key in keys 138 | } 139 | avg_metrics.update( 140 | { 141 | key 142 | + "_average": np.mean( 143 | [metric[key] for metric in metrics if key in metric.keys()] 144 | ) 145 | for key in keys 146 | } 147 | ) 148 | keys = set(key for d in counts for key in d.keys()) 149 | total_metrics = { 150 | key: np.sum([count[key] for count in counts if key in count.keys()]) 151 | for key in keys 152 | } 153 | 154 | outer_keys = set(key for d in nested_counts for key in d.keys()) 155 | inner_keys = set( 156 | key 157 | for d in nested_counts 158 | for inner_dict in d.values() 159 | for key in inner_dict.keys() 160 | ) 161 | 162 | total_nested_counts = {} 163 | for k in outer_keys: 164 | total_nested_counts[k] = { 165 | key: np.sum( 166 | [ 167 | count_dict[k][key] 168 | for count_dict in nested_counts 169 | if key in count_dict[k].keys() 170 | ] 171 | ) 172 | for key in inner_keys 173 | } 174 | 175 | property_metrics = {} 176 | for k in inner_keys: 177 | property_metrics[k] = { 178 | "acc_token": ( 179 | total_nested_counts["per_property_acc_token"][k] 180 | / total_nested_counts["key_matches"][k] 181 | if total_nested_counts["key_matches"][k] > 0 182 | else 0 183 | ), 184 | "acc_aon": ( 185 | total_nested_counts["per_property_acc_aon"][k] 186 | / total_nested_counts["key_matches"][k] 187 | if total_nested_counts["key_matches"][k] > 0 188 | else 0 189 | ), 190 | "key_coverage": ( 191 | total_nested_counts["key_matches"][k] 192 | / total_nested_counts["target_key_occurance"][k] 193 | if total_nested_counts["target_key_occurance"][k] > 0 194 | else 0 195 | ), 196 | "key_precision": ( 197 | total_nested_counts["key_matches"][k] 198 | / total_nested_counts["pred_key_occurance"][k] 199 | if total_nested_counts["pred_key_occurance"][k] > 0 200 | else 0 201 | ), 202 | } 203 | print("Target Count:", target_count) 204 | print("Generated Count:", generated_count) 205 | print("Target Entities without type:", target_entities_without_type) 206 | print("Generated Entities without type:", generated_entities_without_type) 207 | 208 | avg_metrics.update( 209 | { 210 | # "avg_target_entities": avg_target_entities, 211 | # "target_type_counts": target_type_counts, 212 | # "avg_generated_entities": avg_generated_entities, 213 | # "generated_type_counts": generated_type_counts, 214 | # "avg_target_entities_without_type": avg_target_entities_without_type, 215 | # "avg_generated_entities_without_type": avg_generated_entities_without_type, 216 | "combined_coverage": total_metrics["established_entity_matches"] 217 | / total_metrics["target_entities_no_dup"], 218 | "combined_precision": total_metrics["established_entity_matches"] 219 | / total_metrics["predicted_output_entities_no_dup"], 220 | } 221 | ) 222 | 223 | avg_metrics.update(property_metrics) 224 | 225 | result = {} 226 | result["exact_name_max"] = np.mean(exact_name_max) 227 | result["exact_name_precision"] = np.mean(exact_name_precision) 228 | result["exact_name_recall"] = np.mean(exact_name_recall) 229 | result["approx_name_max"] = np.mean(approx_name_max) 230 | result["approx_name_precision"] = np.mean(approx_name_precision) 231 | result["approx_name_recall"] = np.mean(approx_name_recall) 232 | result["multi_prop_max"] = np.mean(multi_prop_max) 233 | result["multi_prop_precision"] = np.mean(multi_prop_precision) 
234 | result["multi_prop_recall"] = np.mean(multi_prop_recall) 235 | result["target_count"] = target_count 236 | result["generated_count"] = generated_count 237 | result["target_type_counts"] = target_type_counts 238 | result["generated_type_counts"] = generated_type_counts 239 | result["target_entities_without_type"] = target_entities_without_type 240 | result["generated_entities_without_type"] = generated_entities_without_type 241 | result.update(avg_metrics) 242 | with open(prediction_path[:-5] + "_metrics.json", "w") as f: 243 | json.dump(result, f, indent=4) 244 | return result 245 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class DataManager(ABC): 5 | @abstractmethod 6 | def create_dataset( 7 | self, file_path: str, use_data: int = None, max_length: int = 512, **kwargs 8 | ): 9 | # return a pytorch dataset 10 | raise NotImplementedError 11 | 12 | @abstractmethod 13 | def create_data_loader( 14 | self, 15 | file_path: str, 16 | use_data: int, 17 | max_length: int, 18 | batch_size: int, 19 | shuffle: bool, 20 | **kwargs 21 | ): 22 | # return pytorch data loaders for training, validation, test 23 | raise NotImplementedError 24 | -------------------------------------------------------------------------------- /data/dataloader_musee.py: -------------------------------------------------------------------------------- 1 | import json 2 | from copy import deepcopy 3 | 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset, random_split 6 | from transformers import T5Tokenizer 7 | 8 | from . import DataManager 9 | 10 | 11 | class WikiDataStep1Manager(DataManager): 12 | class WikiDatasetStep1Filtered(Dataset): 13 | def __init__( 14 | self, 15 | data, 16 | tokenizer, 17 | max_length, 18 | top_entity=None, 19 | top_property=None, 20 | data_name=None, 21 | ): 22 | self.data = data 23 | self.tokenizer = tokenizer 24 | self.max_length = max_length 25 | # self.max_num_entity = max( 26 | # len(item.get("entities", {})) for item in data.values() 27 | # ) 28 | if data_name in ["d3", "d3-hallu"]: 29 | self.max_num_entity = 20 30 | elif data_name in ["d2", "d2-hallu"]: 31 | self.max_num_entity = 10 32 | elif data_name == "toy": 33 | self.max_num_entity = 4 34 | self.max_entities = max(len(item["entities"]) for item in data.values()) 35 | self.max_prop = max( 36 | len(item["entities"][entity]) 37 | for item in data.values() 38 | for entity in item["entities"] 39 | ) 40 | self.max_prop_length = 15 41 | 42 | # Initialize storage variables 43 | self.all_entity_types = [] 44 | self.all_pks = [] 45 | self.entity_type_counts = {} 46 | self.property_key_counts = {} 47 | 48 | self._compute_counts() 49 | 50 | # Filter by top entity and property 51 | if top_entity: 52 | self._filter_top_entities(top_entity) 53 | if top_property: 54 | self._filter_top_properties(top_property) 55 | 56 | self._filter_data() 57 | 58 | self.num_entity_types = len(self.all_entity_types) 59 | self.num_all_pks = len(self.all_pks) 60 | 61 | self._build_template() 62 | 63 | def _compute_counts(self): 64 | """Compute counts for entity types and property keys.""" 65 | for item in self.data.values(): 66 | for entity in item["entities"].values(): 67 | # Exclude the "pk_type" key 68 | self.all_pks.extend([k for k in entity.keys() if k != "pk_type"]) 69 | entity_type = entity["pk_type"] 70 | self.all_entity_types.append(entity_type) 71 | 72 | 
self.entity_type_counts[entity_type] = ( 73 | self.entity_type_counts.get(entity_type, 0) + 1 74 | ) 75 | for pk in entity.keys(): 76 | if pk != "pk_type": 77 | self.property_key_counts[pk] = ( 78 | self.property_key_counts.get(pk, 0) + 1 79 | ) 80 | 81 | # Sorting entity_type_counts from high to low 82 | self.entity_type_counts = dict( 83 | sorted( 84 | self.entity_type_counts.items(), 85 | key=lambda item: item[1], 86 | reverse=True, 87 | ) 88 | ) 89 | 90 | # Sorting property_key_counts from high to low 91 | self.property_key_counts = dict( 92 | sorted( 93 | self.property_key_counts.items(), 94 | key=lambda item: item[1], 95 | reverse=True, 96 | ) 97 | ) 98 | 99 | # Sorting all_entity_types based on entity_type_counts order 100 | self.all_entity_types = sorted( 101 | self.all_entity_types, 102 | key=lambda x: self.entity_type_counts[x], 103 | reverse=True, 104 | ) 105 | 106 | # Sorting all_pks based on property_key_counts order 107 | self.all_pks = sorted( 108 | self.all_pks, key=lambda x: self.property_key_counts[x], reverse=True 109 | ) 110 | 111 | def _filter_top_entities(self, top_entity): 112 | """Filter top entities.""" 113 | sorted_entities = sorted( 114 | self.entity_type_counts.items(), key=lambda x: x[1], reverse=True 115 | ) 116 | self.all_entity_types = [e[0] for e in sorted_entities[:top_entity]] 117 | print(f"top_sorted_entities counts: {dict(sorted_entities[:top_entity])}\n") 118 | 119 | def _filter_top_properties(self, top_property): 120 | """Filter top properties.""" 121 | sorted_properties = sorted( 122 | self.property_key_counts.items(), key=lambda x: x[1], reverse=True 123 | ) 124 | self.all_pks = [p[0] for p in sorted_properties[:top_property]] 125 | print( 126 | f"top_sorted_properties counts: {dict(sorted_properties[:top_property])}\n" 127 | ) 128 | 129 | def _filter_data(self): 130 | """Filter data items and entities.""" 131 | # Filter out items with no entities of desired types 132 | keys_to_remove = [ 133 | key 134 | for key, item in self.data.items() 135 | if not any( 136 | e["pk_type"] in self.all_entity_types 137 | for e in item["entities"].values() 138 | ) 139 | ] 140 | for key in keys_to_remove: 141 | del self.data[key] 142 | 143 | # Filter out entities not having any of the selected top properties 144 | for item in self.data.values(): 145 | entities_to_remove = [ 146 | entity_key 147 | for entity_key, entity in item["entities"].items() 148 | if not (set(entity.keys()) - {"pk_type"}) & set(self.all_pks) 149 | ] 150 | for entity_key in entities_to_remove: 151 | del item["entities"][entity_key] 152 | 153 | # Remove items with no entities or only with "0" type 154 | keys_to_remove = [ 155 | key 156 | for key, item in self.data.items() 157 | if not item.get("entities") 158 | or all( 159 | entity["pk_type"] == 0 for entity in item.get("entities").values() 160 | ) 161 | ] 162 | for key in keys_to_remove: 163 | del self.data[key] 164 | 165 | def _build_template(self): 166 | """Post filtering setup to determine unique entity types and property keys.""" 167 | # self.all_entity_types = sorted(list(set(self.all_entity_types))) 168 | # self.all_pks = sorted(list(set(self.all_pks))) 169 | 170 | # Create a template tensor to store property presence info 171 | self.template = torch.zeros( 172 | (self.num_entity_types, self.num_all_pks), dtype=torch.int 173 | ) 174 | for item in self.data.values(): 175 | for entity_name in item["entities"]: 176 | entity = item["entities"][entity_name] 177 | if entity["pk_type"] in self.all_entity_types: 178 | for pk in entity: 179 | if pk 
in self.all_pks: 180 | self.template[ 181 | self.all_entity_types.index(entity["pk_type"]) 182 | ][self.all_pks.index(pk)] = 1 183 | print("template sum by E_type:", self.template.sum(1)) 184 | print("template sum by Pk:", self.template.sum(0)) 185 | 186 | def __len__(self): 187 | return len(self.data) 188 | 189 | def id2entity(self): 190 | # index 0 is for padded (not-real) entity 191 | return {i + 1: entity for i, entity in enumerate(self.all_entity_types)} 192 | 193 | def id2property(self): 194 | return {i: property_key for i, property_key in enumerate(self.all_pks)} 195 | 196 | def get_entity_label(self): 197 | entity_label = {} 198 | for i, e in enumerate(self.all_entity_types): 199 | entity_label[e] = i + 1 200 | entity_label["padding"] = 0 201 | return entity_label 202 | 203 | def get_all_template(self): 204 | return self.template 205 | 206 | def get_template(self, entity_type: str): 207 | return self.template[self.all_entity_types.index(entity_type)] 208 | 209 | def sort_entity_values(self, entity_values): 210 | # Function to get ID for a key 211 | def get_id(key): 212 | return self.tokenizer.convert_tokens_to_ids(key) 213 | 214 | # Excluding 'pk_type' and sorting the remaining keys based on their IDs 215 | sorted_keys = sorted( 216 | [key for key in entity_values if key != "pk_type"], key=get_id 217 | ) 218 | 219 | # Constructing the sorted dictionary 220 | sorted_dict = {"pk_type": entity_values["pk_type"]} 221 | for key in sorted_keys: 222 | sorted_dict[key] = entity_values[key] 223 | 224 | return sorted_dict 225 | 226 | def __getitem__(self, idx): 227 | item = self.data[list(self.data.keys())[idx]] 228 | description = item.get("description", None) 229 | entities = item.get("entities", None) 230 | 231 | context = item["description"].replace("\n", "") 232 | 233 | context = context.strip() 234 | context = context.encode("ascii", "ignore") 235 | context = context.decode() 236 | while context.find(" ") != -1: 237 | context = context.replace(" ", " ") 238 | context = context.replace(" ,", ",") 239 | context = context.replace(" .", ".") 240 | context = context.replace(" ?", "?") 241 | context = context.replace(" !", "!") 242 | context = context.replace(" :", ":") 243 | context = context.replace(" ;", ";") 244 | context = context.replace("^", "") 245 | context = context.replace("<", "") 246 | context = context.replace(">", "") 247 | context = context.replace(" !", "!") 248 | context = context.replace("~", "") 249 | context = context.replace("\\", "") 250 | context = context.replace("\t", " ") 251 | context = context.replace("{", "") 252 | context = context.replace("}", "") 253 | context = context.strip() 254 | 255 | """ 256 | ### Obtain input_ids and attention mask 257 | """ 258 | src_text = f"{description}" 259 | src_tokenized = self.tokenizer.encode_plus( 260 | src_text, 261 | max_length=self.max_length, 262 | padding="max_length", 263 | return_attention_mask=True, 264 | return_tensors="pt", 265 | truncation=True, 266 | ) 267 | input_ids = src_tokenized["input_ids"].flatten() # (seq_len, ) 268 | attention_mask = src_tokenized["attention_mask"].flatten() # (seq_len, ) 269 | 270 | 271 | # Extracting entity names and formatting them with [SEP] token 272 | entity_names = [entity["pk_entity_name"] for entity in entities.values()] 273 | entity_names = entity_names[ 274 | : self.max_num_entity 275 | ] # only process max_num_entity 276 | real_labels_ent_name = ( 277 | " " + " ".join(entity_names) + " " 278 | ) 279 | labels_ent_name = self.tokenizer( 280 | real_labels_ent_name, 281 | 
max_length=self.max_num_entity * 6, 282 | padding="max_length", 283 | truncation=True, 284 | add_special_tokens=True, 285 | return_tensors="pt", 286 | )["input_ids"].squeeze(0) 287 | 288 | # print("real_labels_ent_name:", real_labels_ent_name) 289 | # print("labels_ent_name:", labels_ent_name) 290 | 291 | # Create the attention mask for the entity positions 292 | # attention_mask_ent = (labels_ent != -1).long() 293 | # attention_mask_ent_tokenized = (labels_ent_tokenized != 0).long() 294 | attention_mask_ent_name = (labels_ent_name != 0).long() 295 | 296 | """ 297 | ### Labels for Step2: Obtain labels for Entity type and property key 298 | """ 299 | end_token_id = self.tokenizer.eos_token_id 300 | labels_pk = torch.full( 301 | (self.max_num_entity, self.num_all_pks + 2), self.tokenizer.pad_token_id 302 | ) # (max_num_Entity, num_all_pks+2) 303 | 304 | for i, (_, entity_values) in enumerate(entities.items()): 305 | entity_values = self.sort_entity_values(entity_values) 306 | if ( 307 | i >= self.max_num_entity 308 | ): # We only process up to max_num_entity entities 309 | break 310 | 311 | # Get the special token ID for the entity type 312 | ent_type_id = self.tokenizer.convert_tokens_to_ids( 313 | entity_values["pk_type"] 314 | ) 315 | labels_pk[i, 0] = ent_type_id 316 | 317 | # Use a counter for the position in the label_ids tensor 318 | pk_counter = 1 319 | 320 | # Iterate over properties 321 | for pk in entity_values: 322 | if pk not in [ 323 | "pk_type", 324 | "pk_entity_name", 325 | ]: # type is already added. No need to predict ent_name 326 | # Use the property key as a special token to get its ID 327 | pk_id = self.tokenizer.convert_tokens_to_ids(pk) 328 | if ( 329 | pk_counter < labels_pk.size(1) - 1 330 | ): # Leave space for the end token 331 | labels_pk[i, pk_counter] = pk_id 332 | pk_counter += 1 333 | 334 | # Add the end token if there's at least one property key 335 | if pk_counter < labels_pk.size(1): 336 | labels_pk[i, pk_counter] = end_token_id 337 | 338 | # The attention mask is binary: 1 for special tokens, 0 for padding tokens 339 | attention_mask_pk = ( 340 | labels_pk != self.tokenizer.pad_token_id 341 | ).long() # (max_num_Entity, num_all_pks+2) 342 | 343 | """ 344 | ### Labels for Step3: Property values (max_num_Entity, num_all_pks, max_prop_len) 345 | """ 346 | labels_pv = torch.full( 347 | (self.max_num_entity, self.num_all_pks, self.max_prop_length), 348 | self.tokenizer.pad_token_id, 349 | ) # Initialize with padding token ID 350 | 351 | for i, (_, entity_values) in enumerate(entities.items()): 352 | entity_values = self.sort_entity_values(entity_values) 353 | if ( 354 | i >= self.max_num_entity 355 | ): # Only process up to max_num_entity entities 356 | break 357 | 358 | # Counter for the property key 359 | pk_counter = 0 360 | for pk in entity_values: 361 | if pk not in [ 362 | "pk_type", 363 | "pk_entity_name", 364 | ]: # type is already added. 
No need to predict ent_name 365 | # Encode the property value 366 | encoded_prop = self.tokenizer.encode( 367 | entity_values[pk] 368 | + self.tokenizer.eos_token, # Encode the property value 369 | add_special_tokens=False, 370 | max_length=self.max_prop_length, 371 | padding="max_length", 372 | truncation=True, 373 | return_tensors="pt", 374 | ).flatten() 375 | # Place the encoded property value in the tensor 376 | labels_pv[i, pk_counter, :] = encoded_prop 377 | pk_counter += 1 378 | 379 | # Create a mask for the encoded properties 380 | attention_mask_pv = (labels_pv != self.tokenizer.pad_token_id).long() 381 | 382 | return { 383 | "input_ids": input_ids, # (seq_len, ) 384 | # "labels_ent": labels_ent, # (max_num_Entity * 2, ) 385 | # "labels_ent_tokenized": labels_ent_tokenized, # (max_num_Entity * 3, ) 386 | "labels_ent_name": labels_ent_name, # (max_num_Entity * 6, ) 387 | "real_labels_ent_name": real_labels_ent_name, # (max_num_Entity * 6, ) 388 | "labels_pk": labels_pk, # (max_num_Entity, num_all_pks+2) 389 | "labels_pv": labels_pv, # (max_num_Entity, num_all_pks, max_prop_len) 390 | "attention_mask": attention_mask, 391 | # "attention_mask_ent": attention_mask_ent, 392 | # "attention_mask_ent_tokenized": attention_mask_ent_tokenized, 393 | "attention_mask_ent_name": attention_mask_ent_name, 394 | "attention_mask_pk": attention_mask_pk, 395 | "attention_mask_pv": attention_mask_pv, 396 | } 397 | 398 | def create_dataset( 399 | self, file_path: str, use_data: int = None, max_length: int = 1024, **kwargs 400 | ): 401 | pretrained_model_name = kwargs.get("model_name", None) 402 | tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name) 403 | special_tokens_need_to_add = [] 404 | ent_type_tokens = [] # List for tokens starting with "ent_type_" 405 | pk_tokens = [] # List for tokens starting with "pk_" 406 | 407 | with open(file_path, "r") as f: 408 | data = json.load(f) 409 | 410 | # Modify the keys in the data 411 | for k, entities in data.items(): 412 | for entity_id, entity in entities["entities"].items(): 413 | new_entity = {} 414 | for prop_key, prop_value in entity.items(): 415 | if prop_key == "type": 416 | new_key = f"pk_{prop_key}" 417 | new_value = f'ent_type_{prop_value.replace(" ", "_")}' 418 | new_entity[new_key] = new_value 419 | special_tokens_need_to_add.append(new_value) 420 | ent_type_tokens.append(new_value) # Add to ent_type_tokens 421 | # elif prop_key != "entity name": # we do not need to train / predict entity name for pk and pv 422 | else: 423 | new_key = f'pk_{prop_key.replace(" ", "_")}' 424 | new_entity[new_key] = prop_value 425 | special_tokens_need_to_add.append(new_key) 426 | pk_tokens.append(new_key) # Add to pk_tokens 427 | entities["entities"][entity_id] = new_entity 428 | 429 | special_tokens_need_to_add = sorted(list(set(special_tokens_need_to_add))) 430 | ent_type_tokens = sorted( 431 | list(set(ent_type_tokens)) 432 | ) # Remove duplicates and sort 433 | pk_tokens = sorted(list(set(pk_tokens))) # Remove duplicates and sort 434 | # print("special_tokens_need_to_add:", len(special_tokens_need_to_add), special_tokens_need_to_add) 435 | tokenizer.add_tokens(special_tokens_need_to_add) 436 | 437 | print("Full data length:", len(data)) 438 | data = dict(list(data.items())[:use_data]) 439 | print("Used data length:", len(data)) 440 | 441 | top_entity = kwargs.get("top_entity", None) 442 | top_property = kwargs.get("top_property", None) 443 | data_name = kwargs.get("data_name", None) 444 | dataset = self.WikiDatasetStep1Filtered( 445 | data, 446 | 
            tokenizer,
447 |             max_length=max_length,
448 |             top_entity=top_entity,
449 |             top_property=top_property,
450 |             data_name=data_name,
451 |         )
452 |         print("Max_num_entity:", dataset.max_num_entity)
453 |         print("Max_num_pk:", dataset.num_all_pks)
454 |         print("Max_prop_len:", dataset.max_prop_length)
455 |         print("*************")
456 | 
457 |         return (
458 |             dataset,
459 |             tokenizer,
460 |             special_tokens_need_to_add,
461 |             ent_type_tokens,
462 |             pk_tokens,
463 |         )
464 | 
465 |     def create_data_loader(
466 |         self,
467 |         file_path: str,
468 |         use_data: int,
469 |         max_length: int,
470 |         batch_size: int,
471 |         shuffle: bool,
472 |         **kwargs,
473 |     ):
474 |         dataset, tokenizer, _, _, _ = self.create_dataset(
475 |             file_path, use_data, max_length, **kwargs
476 |         )
477 | 
478 |         if len(dataset) < 10:
479 |             train_size = len(dataset) - 2
480 |             val_size = 1
481 |             test_size = 1
482 |         else:
483 |             train_size = int(0.8 * len(dataset))
484 |             val_size = int(0.1 * len(dataset))
485 |             test_size = len(dataset) - train_size - val_size
486 | 
487 |         train_dataset, val_dataset, test_dataset = random_split(
488 |             dataset, [train_size, val_size, test_size]
489 |         )
490 | 
491 |         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
492 |         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
493 |         test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)
494 |         print("train_dataloader:", len(train_loader))
495 |         print("val_dataloader:", len(val_loader))
496 |         print("test_dataloader:", len(test_loader))
497 | 
498 |         return train_loader, val_loader, test_loader, dataset, tokenizer
499 | 
--------------------------------------------------------------------------------
/experiment_musee.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | import torch
4 | import wandb
5 | from args import parse_args
6 | from compute_metrics import evaluate
7 | from peft import LoraConfig, PeftModel, get_peft_model
8 | from syne_tune import Reporter
9 | from torch.utils.data import DataLoader
10 | from transformers import get_linear_schedule_with_warmup
11 | from utils import get_attention_paths, print_trainable_parameters, set_seed
12 | 
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 | 
15 | 
16 | def generate_final_json(results, ground_truth_path):
17 |     # Load the ground truth JSON file
18 |     with open(ground_truth_path, "r") as file:
19 |         ground_truth_data = json.load(file)
20 | 
21 |     # Function to process each prediction and return formatted entities
22 |     def process_prediction(ent_tokens, pk_tokens, pv_tokens):
23 |         print("ent_tokens, pk_tokens, pv_tokens:", ent_tokens, pk_tokens, pv_tokens)
24 |         entities = {}
25 |         for i, (ent_name, pk_token, pv_token) in enumerate(
26 |             zip(ent_tokens, pk_tokens, pv_tokens)
27 |         ):
28 |             pk_parts = pk_token.split()
29 | 
30 |             # Skip processing if pk_parts is empty
31 |             if len(pk_parts) == 0:
32 |                 continue
33 | 
34 |             # Extracting entity type
35 |             entity_type = (
36 |                 pk_parts[0].replace("ent_type_", "").replace("_", " ")
37 |                 if "ent_type_" in pk_parts[0]
38 |                 else "unknown"
39 |             )
40 | 
41 |             entity_info = {"type": entity_type}
42 |             entity_info["entity name"] = ent_name  # No need to predict entity name
43 |             for j, key in enumerate(
44 |                 pk_parts[1:]
45 |             ):  # Skip the first token, which is the type
46 |                 prop_key = key.replace("pk_", "").replace("_", " ")
47 |                 if "ent type" in prop_key:
48 |                     continue
49 |                 entity_info[prop_key] = pv_token[j] if j < len(pv_token) else ""
50 |             # if entity_type == "human":
51 |             #     entity_info["given name"] = ent_name.split()[0]
52 |             #     entity_info["family name"] = ent_name.split()[-1]
53 | 
54 |             entities[str(i)] = entity_info
55 |         return entities
56 | 
57 |     # Create the final JSON object
58 |     final_json = {}
59 |     for doc_id, prediction in zip(ground_truth_data, results):
60 |         ent_tokens = prediction.get("predict_ent", [])
61 |         pk_tokens = prediction.get("predict_pk", [])
62 |         pv_tokens = prediction.get("predict_pv", [])
63 | 
64 |         entities = process_prediction(ent_tokens, pk_tokens, pv_tokens)
65 |         final_json[doc_id] = {
66 |             "doc_id": doc_id,
67 |             "description": ground_truth_data[doc_id]["description"],
68 |             "entities": entities,
69 |         }
70 | 
71 |     return final_json
72 | 
73 | 
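# The per-document prediction dicts consumed by generate_final_json above are
# expected to look like this (illustrative sketch only; the actual strings
# depend on the trained model's decoding):
#     {
#         "predict_ent": ["Barack Obama"],                 # entity names
#         "predict_pk": ["ent_type_human pk_occupation"],  # type token followed by property-key tokens
#         "predict_pv": [["politician"]],                  # property values aligned with the keys
#     }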
entity_type == "human": 51 | # entity_info["given name"] = ent_name.split()[0] 52 | # entity_info["family name"] = ent_name.split()[-1] 53 | 54 | entities[str(i)] = entity_info 55 | return entities 56 | 57 | # Create the final JSON object 58 | final_json = {} 59 | for doc_id, prediction in zip(ground_truth_data, results): 60 | ent_tokens = prediction.get("predict_ent", []) 61 | pk_tokens = prediction.get("predict_pk", []) 62 | pv_tokens = prediction.get("predict_pv", []) 63 | 64 | entities = process_prediction(ent_tokens, pk_tokens, pv_tokens) 65 | final_json[doc_id] = { 66 | "doc_id": doc_id, 67 | "description": ground_truth_data[doc_id]["description"], 68 | "entities": entities, 69 | } 70 | 71 | return final_json 72 | 73 | 74 | def experiment(args): 75 | # Due to the forced_decoder_ids does not support batch, we have to set batch_size=1 for inference 76 | if args.mode == "test": 77 | args.batch_size = 1 78 | 79 | if args.model_choice == "MuSEE": 80 | from trainer.trainer_musee import Trainer_E_Pk_Pv 81 | 82 | trainer = Trainer_E_Pk_Pv() 83 | if args.log_wandb: 84 | wandb.login() 85 | from data.dataloader_musee import WikiDataStep1Manager 86 | 87 | manager = WikiDataStep1Manager() 88 | 89 | if args.dataset == "toy": 90 | data_abbrev = "toy" 91 | train_data_path = "data/toy/D3_toy.json" 92 | val_data_path = "data/toy/D3_toy.json" 93 | test_data_path = "data/toy/D3_toy.json" 94 | # test_data_path = "data/toy/dummy.json" 95 | use_data = 100 96 | max_length = 512 97 | elif args.dataset == "gpt4": 98 | data_abbrev = "d2" 99 | train_data_path = "data/D2_final/D2_train_final.json" 100 | val_data_path = "data/D2_final/D2_val_final.json" 101 | test_data_path = "data/D2_final/D2_test_final.json" 102 | use_data = 20000 103 | max_length = 512 104 | elif args.dataset == "wikidata": 105 | data_abbrev = "d3" 106 | train_data_path = "data/D3_final/D3_train_final.json" 107 | val_data_path = "data/D3_final/D3_val_final.json" 108 | test_data_path = "data/D3_final/D3_test_final.json" 109 | use_data = 20000 110 | max_length = 512 111 | elif args.dataset == "wikidata-hallu": 112 | data_abbrev = "d3" 113 | train_data_path = "data/D3_final/D3_train_final.json" 114 | val_data_path = "data/D3_final/D3_val_final.json" 115 | test_data_path = "data/D3_final/D3_test_final_hallu_5k.json" 116 | use_data = 20000 117 | max_length = 512 118 | elif args.dataset == "internal": 119 | data_abbrev = "inter" 120 | raise NotImplementedError 121 | 122 | # Get dataloader 123 | dataset, tokenizer, special_tokens_need_to_add, ent_type_tokens, pk_tokens = ( 124 | manager.create_dataset( 125 | file_path=train_data_path, 126 | model_name=args.pretrained_model_name, 127 | use_data=use_data, 128 | max_length=max_length, 129 | batch_size=args.batch_size, 130 | shuffle=False, 131 | if_filter=True, 132 | top_entity=args.top_entity, 133 | top_property=args.top_property, 134 | data_name=data_abbrev, 135 | ) 136 | ) 137 | val_dataset, _, _, _, _ = manager.create_dataset( 138 | file_path=val_data_path, 139 | model_name=args.pretrained_model_name, 140 | use_data=use_data, 141 | max_length=max_length, 142 | batch_size=args.batch_size, 143 | shuffle=False, 144 | if_filter=True, 145 | top_entity=args.top_entity, 146 | top_property=args.top_property, 147 | data_name=data_abbrev, 148 | ) 149 | test_dataset, _, _, _, _ = manager.create_dataset( 150 | file_path=test_data_path, 151 | model_name=args.pretrained_model_name, 152 | use_data=use_data, 153 | max_length=max_length, 154 | batch_size=args.batch_size, 155 | shuffle=False, 156 | if_filter=True, 157 
| top_entity=args.top_entity, 158 | top_property=args.top_property, 159 | data_name=data_abbrev, 160 | ) 161 | # Get the indices of the new tokens 162 | added_new_token_ids = tokenizer.convert_tokens_to_ids( 163 | special_tokens_need_to_add 164 | ) 165 | added_ent_type_tokens = tokenizer.convert_tokens_to_ids(ent_type_tokens) 166 | added_pk_tokens = tokenizer.convert_tokens_to_ids(pk_tokens) 167 | print("added_new_token_ids:", added_new_token_ids) 168 | print("added_ent_type_tokens:", added_ent_type_tokens) 169 | print("added_pk_tokens:", added_pk_tokens) 170 | 171 | train_dataloader = DataLoader( 172 | dataset, batch_size=args.batch_size, shuffle=False 173 | ) 174 | val_dataloader = DataLoader( 175 | val_dataset, batch_size=args.batch_size, shuffle=False 176 | ) 177 | test_dataloader = DataLoader( 178 | test_dataset, batch_size=args.batch_size, shuffle=False 179 | ) 180 | 181 | max_seq_length = dataset.max_length 182 | num_entity_types = dataset.num_entity_types 183 | max_num_entity = dataset.max_num_entity 184 | num_property_keys = dataset.num_all_pks 185 | all_entity_types = dataset.all_entity_types 186 | entity_type_counts = dataset.entity_type_counts 187 | entity_type_counts["0"] = use_data * max_num_entity - sum( 188 | entity_type_counts.values() 189 | ) 190 | entity_type_counts = { 191 | k: v 192 | for k, v in sorted( 193 | entity_type_counts.items(), key=lambda item: item[1], reverse=True 194 | ) 195 | } 196 | property_key_counts = dataset.property_key_counts 197 | print("-----------") 198 | print("max_seq_length:", max_seq_length) 199 | print("num_entity_types:", num_entity_types) 200 | print("max_num_entity:", max_num_entity) 201 | print("num_property_keys:", num_property_keys) 202 | print("entity_type_counts:", len(entity_type_counts), entity_type_counts) 203 | print("property_key_counts:", property_key_counts) 204 | 205 | # type_weights = compute_inverse_frequency_weights(entity_type_counts, num_entity_types).to(device) 206 | # print("type_weights:", type_weights) 207 | 208 | # original_template = dataset.get_all_template().numpy() 209 | # all_zero_row = np.zeros( 210 | # original_template.shape[1], dtype=original_template.dtype 211 | # ) 212 | # template = np.vstack( 213 | # (all_zero_row, original_template) 214 | # ) # add all-zero row for type 0 215 | # template = torch.tensor(template).to(device) 216 | # print("template:", template.shape) 217 | from trainer.trainer_musee import Predictor_E_Pk_Pv 218 | 219 | model = Predictor_E_Pk_Pv( 220 | pretrained_model_name=args.pretrained_model_name, 221 | max_seq_length=max_seq_length, 222 | max_num_entity=max_num_entity, 223 | tokenizer=tokenizer, 224 | ).to(device) 225 | model.t5_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32) 226 | print_trainable_parameters(model) 227 | 228 | mask_token, sep_token = "", "" 229 | mask_token_id = torch.tensor( 230 | tokenizer.encode(mask_token, add_special_tokens=False)[0] 231 | ).item() 232 | sep_token_id = torch.tensor( 233 | tokenizer.encode(sep_token, add_special_tokens=False)[0] 234 | ).item() 235 | 236 | vocab_size = model.t5_model.get_input_embeddings().weight.size(0) 237 | 238 | print("vocab_size:", vocab_size) 239 | print("mask_token_id:", mask_token_id) 240 | print("sep_token_id:", sep_token_id) 241 | print("--------------------") 242 | 243 | # # Set up wandb 244 | if args.log_wandb: 245 | run_name = "lr{}-wd{}-{}".format(args.lr, args.weight_decay, args.loss_mode) 246 | wandb.init( 247 | project="MuSEE-full-{}-{}-{}-lora-{}-init-{}".format( 248 | data_abbrev, 249 | 
args.pretrained_model_name, 250 | args.use_lora, 251 | args.loss_mode, 252 | args.use_better_init, 253 | ), 254 | config=args, 255 | name=run_name, # set the run name here 256 | ) 257 | 258 | save_path = ( 259 | f"saved/best_model/MuSEE/MuSEE_{data_abbrev}_m{args.pretrained_model_name}_" 260 | f"lr{args.lr}_wd{args.weight_decay}_{args.loss_mode}_lora_{args.use_lora}_init_{args.use_better_init}" 261 | f"best_model" 262 | ) 263 | 264 | if args.use_better_init: 265 | print("Better initialize the special tokens' embeddings") 266 | print( 267 | "special_tokens_need_to_add:", 268 | len(special_tokens_need_to_add), 269 | special_tokens_need_to_add, 270 | ) 271 | # Get the embeddings layer from the model 272 | embedding_layer = model.t5_model.get_input_embeddings() 273 | print( 274 | "old:", 275 | model.t5_model.get_input_embeddings().weight.shape, 276 | model.t5_model.get_input_embeddings().weight.sum(), 277 | ) 278 | 279 | # Calculate new embeddings 280 | new_token_embeddings = [] 281 | for token in special_tokens_need_to_add: 282 | # Tokenize the special token into subwords 283 | token = token.replace("ent_type", "") 284 | token = token.replace("pk", "") 285 | token = token.replace("_", " ") 286 | subtokens = tokenizer.tokenize(token) 287 | 288 | # Get the embeddings for the subtokens 289 | subtoken_ids = tokenizer.convert_tokens_to_ids(subtokens) 290 | subtoken_embeddings = embedding_layer.weight[subtoken_ids] 291 | 292 | # Calculate the average embedding 293 | average_embedding = subtoken_embeddings.mean(dim=0) 294 | 295 | # Add Gaussian noise to the average embedding 296 | noise = torch.randn(average_embedding.size()) * args.noise_std_dev 297 | new_embedding = average_embedding + noise.to(device) 298 | 299 | # Append to the list of new token embeddings 300 | new_token_embeddings.append(new_embedding) 301 | 302 | # Convert the list to a tensor 303 | new_token_embeddings = torch.stack(new_token_embeddings) 304 | 305 | # Set the embeddings for the new tokens in the model 306 | with torch.no_grad(): 307 | # Get the indices of the new tokens 308 | new_token_ids = tokenizer.convert_tokens_to_ids( 309 | special_tokens_need_to_add 310 | ) 311 | # Update the embeddings for these tokens 312 | embedding_layer.weight[new_token_ids] = new_token_embeddings 313 | print( 314 | "new:", 315 | model.t5_model.get_input_embeddings().weight.shape, 316 | model.t5_model.get_input_embeddings().weight.sum(), 317 | ) 318 | 319 | # Set embedding layer as trainable 320 | model.t5_model.shared.weight.requires_grad = True 321 | 322 | if args.mode == "train": 323 | if args.use_lora: 324 | target_modules = get_attention_paths(model) 325 | modules_to_save = ["shared"] 326 | 327 | lora_config = LoraConfig( 328 | r=args.lora_r, 329 | lora_alpha=args.lora_alpha, 330 | lora_dropout=args.lora_dropout, 331 | target_modules=target_modules, 332 | modules_to_save=modules_to_save, 333 | ) 334 | 335 | model = get_peft_model(model, lora_config) 336 | print_trainable_parameters(model) 337 | 338 | optimizer = torch.optim.AdamW( 339 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay 340 | ) 341 | # Set up the learning rate scheduler 342 | total_steps = len(train_dataloader) * args.epochs 343 | scheduler = get_linear_schedule_with_warmup( 344 | optimizer, num_warmup_steps=0, num_training_steps=total_steps 345 | ) 346 | report = Reporter() 347 | trainer.train( 348 | save_path, 349 | model, 350 | train_dataloader, 351 | val_dataloader, 352 | optimizer, 353 | scheduler, 354 | args.epochs, 355 | device=device, 356 | 
log_wandb=args.log_wandb, 357 | use_lora=args.use_lora, 358 | alpha=args.alpha, 359 | added_ent_type_tokens=added_ent_type_tokens, 360 | added_pk_tokens=added_pk_tokens, 361 | loss_mode=args.loss_mode, 362 | reporter=report, 363 | ) 364 | 365 | if args.use_lora: 366 | print( 367 | f"t5_model.shared.original_module", 368 | model.t5_model.shared.original_module.weight.sum(), 369 | ) 370 | print( 371 | f"t5_model.shared.modules_to_save", 372 | model.t5_model.shared.modules_to_save["default"].weight.sum(), 373 | ) 374 | 375 | elif args.mode == "test": 376 | save_path = ( 377 | f"saved/best_model/MuSEE/MuSEE_{data_abbrev}_m{args.pretrained_model_name}_" 378 | f"lr{args.lr}_wd{args.weight_decay}_{args.loss_mode}_lora_{args.use_lora}_init_{args.use_better_init}" 379 | f"best_model" 380 | ) 381 | print("save_path:", save_path) 382 | if args.use_lora: 383 | model = PeftModel.from_pretrained(model, save_path) 384 | print( 385 | f"t5_model.shared.original_module", 386 | model.t5_model.shared.original_module.weight.sum(), 387 | ) 388 | print( 389 | f"t5_model.shared.modules_to_save", 390 | model.t5_model.shared.modules_to_save["default"].weight.sum(), 391 | ) 392 | model = model.merge_and_unload() 393 | else: 394 | model.load_state_dict( 395 | torch.load(f"{save_path}.pt", map_location=device) 396 | ) 397 | 398 | print( 399 | "after load pretrained (get_input_embeddings):", 400 | model.t5_model.get_input_embeddings().weight.shape, 401 | model.t5_model.get_input_embeddings().weight.sum(), 402 | ) 403 | 404 | print( 405 | "after load pretrained (shared):", 406 | model.t5_model.shared.weight.shape, 407 | model.t5_model.shared.weight.sum(), 408 | ) 409 | 410 | model.eval() 411 | # generate json output 412 | # id2entity, id2property = dataset.id2entity(), dataset.id2property() 413 | results = Trainer_E_Pk_Pv.generate_full_json_output( 414 | model, 415 | test_dataloader, 416 | added_ent_type_tokens, 417 | added_pk_tokens, 418 | tokenizer, 419 | device, 420 | mode=args.mode, 421 | ) 422 | 423 | final_json = generate_final_json(results, test_data_path) 424 | print("final_json:", json.dumps(final_json, indent=4)) 425 | 426 | # Save to a JSON file 427 | prediction_path = ( 428 | f"saved/best_model/MuSEE/saved_json/MuSEE_{data_abbrev}_m{args.pretrained_model_name}_" 429 | f"lr{args.lr}_wd{args.weight_decay}_{args.loss_mode}_lora_{args.use_lora}_init_{args.use_better_init}.json" 430 | ) 431 | with open(prediction_path, "w", encoding="utf-8") as file: 432 | json.dump(final_json, file, ensure_ascii=False, indent=4) 433 | 434 | metrics = evaluate(test_data_path, prediction_path) 435 | 436 | 437 | def run(): 438 | args = parse_args() 439 | set_seed(args.seed) 440 | experiment(args) 441 | 442 | 443 | if __name__ == "__main__": 444 | run() 445 | -------------------------------------------------------------------------------- /img/metric.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/img/metric.png -------------------------------------------------------------------------------- /img/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/img/model.png -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 
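metric.py implements the approximate-matching evaluation: pairwise entity distances (exact name match, approximate name match, or multi-property token overlap) are fed to a minimum-weight bipartite matching, and the matched similarities are normalized by the prediction set size ("Precision"), the target set size ("Recall"), or the larger of the two ("Max"). A minimal usage sketch of the functions defined in this file (illustrative only; the entity dicts below are made up, following the "entity name"/"type" key convention produced by generate_final_json above):

# Hypothetical example data, not from the repo's datasets
target = [
    {"entity name": "Barack Obama", "type": "human", "occupation": "politician"},
    {"entity name": "Honolulu", "type": "city"},
]
prediction = [{"entity name": "Barack H. Obama", "type": "human"}]

# measures must be one of "ExactName", "ApproxName", "MultiProp";
# normalization one of "Max", "Precision", "Recall"
final_metrics, counts, counts_nested = compute_bipartite_matching_metrics(
    target, prediction, measures="ApproxName", normalization="Recall"
)
# Prints a similarity in [0, 1] and the number of matched entity pairs
print(final_metrics["normalized_similarity"], counts["established_entity_matches"])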
1 | import json
2 | from copy import deepcopy
3 | 
4 | import torch
5 | from scipy.sparse import csr_matrix
6 | from scipy.sparse.csgraph import min_weight_full_bipartite_matching
7 | 
8 | # from utils import remove_duplicates_and_postprocess
9 | 
10 | 
11 | def compute_distances(gt, pred, measures, weights_pks=None):
12 |     """
13 |     Args:
14 |         gt: list of ground-truth entities, where each entity has
15 |             - a 't' field consisting of the one-hot type torch.tensor
16 |             - a 'pk' field consisting of multihot torch.tensor indicating the present pks
17 |             - a 'pv' field containing a list of the property values, in an order consistent with the 'pk' field.
18 |         pred: list of predicted entities
19 |         measures: one of "ExactName", "ApproxName" or "MultiProp", selecting the distance measure to apply
20 |         weights_pks: torch.tensor defining a weight for each property key to compute weighted averages in the distance metrics.
21 |     """
22 |     assert measures in ["ExactName", "ApproxName", "MultiProp"]
23 | 
24 |     # if weights_pks is None:
25 |     #     weights_pks = torch.ones_like(gt[0]["pk"])
26 | 
27 |     # def entity_type_distance_ce(e1, e2):
28 |     #     t1, t2 = e1["t"].float(), e2["t"].float()
29 |     #     epsilon = 1e-5
30 |     #     t2 = torch.clamp(t2, min=epsilon)  # Ensure values are not too close to zero
31 |     #     return -torch.sum(t1 * torch.log(t2))
32 |     #
33 |     # def entity_type_distance_acc(e1, e2):
34 |     #     acc = (torch.argmax(e1["t"]) == torch.argmax(e2["t"])).float()
35 |     #     return 1 - acc
36 |     def entity_name_distance_approx(
37 |         e1, e2, weights
38 |     ):  # same as pv, but the weight is 1 for the name and 0 for the rest
39 |         v1, v2 = e1["pv"], e2["pv"]
40 |         # Use Jaccard similarity to compute the distance between the names
41 |         # Split each property value into tokens (words)
42 |         tokens_v1 = [
43 |             set(value.lower().split())
44 |             for index, value in enumerate(v1)
45 |             if weights[index] == 1
46 |         ]  # only for name
47 |         tokens_v2 = [
48 |             set(value.lower().split())
49 |             for index, value in enumerate(v2)
50 |             if weights[index] == 1
51 |         ]  # only for name
52 |         jaccard_similarities = []
53 |         _weights = weights.clone()
54 |         for i, (t1, t2) in enumerate(zip(tokens_v1, tokens_v2)):
55 |             # Compute the Jaccard similarity for the token sets
56 |             intersection_size = len(t1.intersection(t2))
57 |             union_size = len(t1.union(t2))
58 |             if union_size == 0:
59 |                 jaccard_sim = 0.0
60 |                 _weights[i] = 0.0
61 |             else:
62 |                 jaccard_sim = intersection_size / union_size
63 |             jaccard_similarities.append(jaccard_sim)
64 |         dist = (
65 |             1 - (torch.tensor(jaccard_similarities) * _weights).sum() / _weights.sum()
66 |         )
67 |         return dist
68 | 
69 |     def entity_name_distance_exact(
70 |         e1, e2, weights
71 |     ):  # same as pv, but the weight is 1 for the name and 0 for the rest
72 |         v1, v2 = e1["pv"], e2["pv"]
73 |         v1 = [
74 |             value for index, value in enumerate(v1) if weights[index] == 1
75 |         ]  # only for name
76 |         v2 = [
77 |             value for index, value in enumerate(v2) if weights[index] == 1
78 |         ]  # only for name
79 |         matching = torch.tensor([int(a.lower() == b.lower()) for a, b in zip(v1, v2)])
80 |         if len(matching) == 0:
81 |             return 1.0
82 |         else:
83 |             return 1 - (matching.float() * weights).sum() / weights.sum()
84 | 
85 |     # def property_key_distance_bce(e1, e2):
86 |     #     k1, k2 = e1["pk"].float(), e2["pk"].float()
87 |     #     epsilon = 1e-5
88 |     #     k2 = torch.clamp(
89 |     #         k2, min=epsilon, max=1 - epsilon
90 |     #     )  # Ensures values are between epsilon and 1-epsilon
91 |     #     return -torch.sum(k1 * torch.log(k2) + (1 - k1) * torch.log(1 - k2))
92 |     #
93 |     # def property_key_distance_acc(e1, e2):
94 |     #     k1, k2 = e1["pk"].float(), e2["pk"].float()
95 |     #     # k2_preds = (k2 
>= 0.5).float() 96 | # corrects = (k2_preds == k1).float().sum() 97 | # acc = corrects / len(k1) 98 | # return 1 - acc 99 | # 100 | # def property_value_prop_distance_acc(e1, e2, weights): 101 | # v1, v2 = e1["pv"], e2["pv"] 102 | # 103 | # matching = torch.tensor([int(a.lower() == b.lower()) for a, b in zip(v1, v2)]) 104 | # 105 | # return 1 - (matching.float() * weights).sum() / weights.sum() 106 | 107 | def property_value_token_distance_acc(e1, e2, weights): 108 | v1, v2 = e1["pv"], e2["pv"] 109 | 110 | # Split each property value into tokens (words) 111 | tokens_v1 = [set(value.lower().split()) for value in v1] 112 | tokens_v2 = [set(value.lower().split()) for value in v2] 113 | 114 | jaccard_similarities = [] 115 | _weights = weights.clone() 116 | for i, (t1, t2) in enumerate(zip(tokens_v1, tokens_v2)): 117 | # Compute the Jaccard similarity for the token sets 118 | intersection_size = len(t1.intersection(t2)) 119 | union_size = len(t1.union(t2)) 120 | if union_size == 0: 121 | jaccard_sim = 0.0 122 | _weights[i] = 0.0 123 | else: 124 | jaccard_sim = intersection_size / union_size 125 | jaccard_similarities.append(jaccard_sim) 126 | dist = ( 127 | 1 - (torch.tensor(jaccard_similarities) * _weights).sum() / _weights.sum() 128 | ) 129 | return dist 130 | 131 | distances = torch.zeros((len(gt), len(pred))) 132 | 133 | for i, g in enumerate(gt): 134 | for j, p in enumerate(pred): 135 | distance = 0 136 | if measures == "ExactName": 137 | distance = entity_name_distance_exact(g, p, weights_pks) 138 | elif measures == "ApproxName": 139 | distance = entity_name_distance_approx(g, p, weights_pks) 140 | elif measures == "MultiProp": # We can also use bce here 141 | distance = property_value_token_distance_acc(g, p, weights_pks) 142 | # if "E-CE" in measures: 143 | # distance += entity_type_distance_ce(g, p) 144 | # if "E-ACC" in measures: 145 | # distance += entity_type_distance_acc(g, p) 146 | # if "Pk-BCE" in measures: 147 | # distance += property_key_distance_bce(g, p) 148 | # if "Pk-ACC" in measures: 149 | # distance += property_key_distance_acc(g, p) 150 | # if "Pv-prop-ACC" in measures: 151 | # distance += property_value_prop_distance_acc(g, p, weights_pks) 152 | # if "Pv-token-ACC" in measures: 153 | # distance += property_value_token_distance_acc(g, p, weights_pks) 154 | distances[i, j] = distance 155 | 156 | return distances 157 | 158 | 159 | def bipartite_matching(distances): 160 | # Max is the maximum size of ground-truth set size and prediction set size 161 | # Precision is the size of the prediction set size 162 | # Recall is the size of the ground-truth set size 163 | biadjacency_matrix = csr_matrix(distances.numpy()) 164 | # Add a constant (e.g., 1) to every distance to ensure no zero values 165 | biadjacency_matrix = biadjacency_matrix + csr_matrix( 166 | torch.ones_like(distances).numpy() 167 | ) 168 | # print("biadjacency_matrix:", biadjacency_matrix.todense()) 169 | row_ind, col_ind = min_weight_full_bipartite_matching( 170 | biadjacency_matrix, maximize=False 171 | ) 172 | 173 | # Subtract the added constant for each matched pair 174 | min_num_entity = min(biadjacency_matrix.shape[0], biadjacency_matrix.shape[1]) 175 | max_num_entity = max(biadjacency_matrix.shape[0], biadjacency_matrix.shape[1]) 176 | matched_distances = biadjacency_matrix[row_ind, col_ind].sum() - min_num_entity 177 | # # optimal_metric_loss = ( 178 | # # (matched_distances + max_num_entity - min_num_entity) / max_num_entity 179 | # # if max_num_entity != 0 180 | # # else 0 181 | # # ) 182 | # 
optimal_metric_loss = (matched_distances + denominator - min_num_entity) / denominator if denominator != 0 else 0 183 | # Obtain permutation 184 | permutation_ground_truth = torch.tensor(row_ind)[ 185 | torch.argsort(torch.tensor(col_ind)) 186 | ] 187 | permutation_prediction = torch.tensor(col_ind) 188 | return permutation_ground_truth, permutation_prediction 189 | # return optimal_metric_loss, permutation_ground_truth, permutation_prediction 190 | 191 | 192 | def compute_bipartite_matching_metrics( 193 | target: list, 194 | predicted_output: list, 195 | measures, 196 | normalization, 197 | establish_threshold=0.6, 198 | ): 199 | """Compute metrics based on bipartite matching""" 200 | assert normalization in ["Max", "Precision", "Recall"] 201 | target = deepcopy(target) 202 | predicted_output = deepcopy(predicted_output) 203 | target_set_size = len(target) 204 | predicted_output_set_size = len(predicted_output) 205 | if isinstance(target, dict): 206 | target = [entity for entity in target.values()] 207 | # print("Prediction", predicted_output) 208 | # print("Target", target) 209 | keys = set(key for entity in target + predicted_output for key in entity.keys()) 210 | # print("keys:", keys) 211 | keys = list(keys) 212 | 213 | # target = remove_duplicates_and_postprocess(target) 214 | # predicted_output = remove_duplicates_and_postprocess(predicted_output) 215 | 216 | # Pad the predicted entities or ground-truth with dummy entities 217 | # to ensure that the number of entities is the same 218 | if target_set_size > predicted_output_set_size: 219 | predicted_output += [ 220 | {} for _ in range(target_set_size - predicted_output_set_size) 221 | ] 222 | elif predicted_output_set_size > target_set_size: 223 | target += [{} for _ in range(predicted_output_set_size - target_set_size)] 224 | 225 | # print("Prediction", predicted_output) 226 | # print("Target", target) 227 | 228 | def get_key_index(key, keys): 229 | for i, k in enumerate(keys): 230 | if key == k: 231 | return i 232 | raise ValueError(f"The key {key} does not exists in the key list") 233 | 234 | # Create property key tensors 235 | def create_pk_tensor(entity, keys): 236 | tensor = [0] * len(keys) 237 | for key in entity.keys(): 238 | tensor[get_key_index(key, keys)] = 1 239 | return torch.tensor(tensor) 240 | 241 | def create_pv_list(entity, keys): 242 | lst = [""] * len(keys) 243 | for key, value in entity.items(): 244 | if not isinstance(value, str): 245 | if isinstance(value, list): 246 | value = " ".join(value) 247 | else: 248 | value = str(value) 249 | lst[get_key_index(key, keys)] = value 250 | 251 | return lst 252 | 253 | def jaccard_similarity(tokens_target, tokens_pred): 254 | # Compute the Jaccard similarity (intersection of the token set over the union) 255 | intersection_size = len(tokens_target.intersection(tokens_pred)) 256 | union_size = len(tokens_target.union(tokens_pred)) 257 | return intersection_size / union_size 258 | 259 | target_entities = [ 260 | {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)} for e in target 261 | ] 262 | predicted_entities = [ 263 | {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)} 264 | for e in predicted_output 265 | ] 266 | 267 | # Pad the predicted entities or ground-truth with dummy entities 268 | # to ensure that the number of entities is the same 269 | # if target_set_size > predicted_output_set_size: 270 | # predicted_entities += [ 271 | # {"pk": torch.zeros_like(target_entities[0]["pk"]), "pv": [""]} 272 | # for _ in range(target_set_size - 
predicted_output_set_size) 273 | # ] 274 | # elif predicted_output_set_size > target_set_size: 275 | # target_entities += [ 276 | # {"pk": torch.zeros_like(predicted_entities[0]["pk"]), "pv": [""]} 277 | # for _ in range(predicted_output_set_size - target_set_size) 278 | # ] 279 | 280 | # assume a weight of 1 for each property apart from name 281 | try: 282 | weights = torch.zeros_like(target_entities[0]["pk"]) 283 | except IndexError: 284 | print(target_entities) 285 | print(target) 286 | print(predicted_output) 287 | # try: 288 | # if measures == "ExactName": 289 | # weights[get_key_index("name", keys)] = 1 290 | # elif measures == "ApproxName": 291 | # weights[get_key_index("name", keys)] = 1 292 | # elif measures == "MultiProp": 293 | # weights[get_key_index("name", keys)] = 2 294 | # for index, key in enumerate(keys): 295 | # if key != "name": 296 | # weights[index] = 1 297 | # except ValueError: 298 | 299 | if measures == "ExactName": 300 | weights[get_key_index("entity name", keys)] = 1 301 | elif measures == "ApproxName": 302 | weights[get_key_index("entity name", keys)] = 1 303 | elif measures == "MultiProp": 304 | weights[get_key_index("entity name", keys)] = 11 305 | for index, key in enumerate(keys): 306 | if key != "entity name": 307 | weights[index] = 1 308 | # pv_distances_aon = compute_distances( 309 | # target_entities, predicted_entities, ["Pv-prop-ACC"], weights 310 | # ) 311 | # pv_distances_token = compute_distances( 312 | # target_entities, predicted_entities, ["Pv-token-ACC"], weights 313 | # ) 314 | # pv_distances_aon_unweighted = compute_distances( 315 | # target_entities, predicted_entities, ["Pv-prop-ACC"], torch.ones_like(weights) 316 | # ) 317 | # pv_distances_token_unweighted = compute_distances( 318 | # target_entities, predicted_entities, ["Pv-token-ACC"], torch.ones_like(weights) 319 | # ) 320 | # pk_distances_acc = compute_distances( 321 | # target_entities, predicted_entities, ["Pk-ACC"], torch.ones_like(weights) 322 | # ) 323 | 324 | # ( 325 | # pv_distances_token_loss, 326 | # permutation_target, 327 | # permutation_prediction, 328 | # ) = bipartite_matching(pv_distances_token) 329 | # pv_distances_aon_loss, _, _ = bipartite_matching(pv_distances_aon) 330 | # pk_distances_acc_loss, _, _ = bipartite_matching(pk_distances_acc) 331 | # pv_distances_token_unweighted_loss, _, _ = bipartite_matching( 332 | # pv_distances_token_unweighted 333 | # ) 334 | # pv_distances_aon_unweighted_loss, _, _ = bipartite_matching( 335 | # pv_distances_aon_unweighted 336 | # ) 337 | 338 | if measures == "ExactName": 339 | entity_distance = compute_distances( 340 | target_entities, predicted_entities, "ExactName", weights 341 | ) 342 | elif measures == "ApproxName": 343 | entity_distance = compute_distances( 344 | target_entities, predicted_entities, "ApproxName", weights 345 | ) 346 | elif measures == "MultiProp": 347 | entity_distance = compute_distances( 348 | target_entities, predicted_entities, "MultiProp", weights 349 | ) 350 | 351 | # if normalization == "Max": 352 | # ( 353 | # permutation_target, 354 | # permutation_prediction, 355 | # ) = bipartite_matching(entity_distance) 356 | # elif normalization == "Precision": 357 | # ( 358 | # permutation_target, 359 | # permutation_prediction, 360 | # ) = bipartite_matching(entity_distance) 361 | # elif normalization == "Recall": 362 | ( 363 | permutation_target, 364 | permutation_prediction, 365 | ) = bipartite_matching(entity_distance) 366 | 367 | # Only establish matches that have a distance below threshold 368 | # the 
threshold and weight_pks is calibrated such that it does not suffice to have 369 | # a matched type property without a matched name, if the entity only contains name and type (which is often the case) 370 | established_entity_matches = [] 371 | established_entity_matches_tensor = [] 372 | # print("permutation_target", permutation_target) 373 | # print("permutation_prediction", permutation_prediction) 374 | # print("target", target) 375 | # print("predicted_output", predicted_output) 376 | for predicted_idx, target_idx in enumerate(permutation_target): 377 | # if entity_distance[target_idx, predicted_idx] <= establish_threshold: 378 | # established_entity_matches.append( 379 | # (target[target_idx], predicted_output[predicted_idx]) 380 | # ) 381 | # established_entity_matches.append( 382 | # (target[target_idx], predicted_output[predicted_idx]) 383 | # ) 384 | established_entity_matches.append( 385 | (target[target_idx], predicted_output[predicted_idx]) 386 | ) 387 | established_entity_matches_tensor.append( 388 | (target_entities[target_idx], predicted_entities[predicted_idx]) 389 | ) 390 | 391 | def property_value_token_distance_acc(e1, e2, weights): 392 | v1, v2 = e1["pv"], e2["pv"] 393 | 394 | # Split each property value into tokens (words) 395 | tokens_v1 = [set(value.lower().split()) for value in v1] 396 | tokens_v2 = [set(value.lower().split()) for value in v2] 397 | 398 | jaccard_similarities = [] 399 | _weights = weights.clone() 400 | for i, (t1, t2) in enumerate(zip(tokens_v1, tokens_v2)): 401 | # Compute the Jaccard similarity for the token sets 402 | intersection_size = len(t1.intersection(t2)) 403 | union_size = len(t1.union(t2)) 404 | if union_size == 0: 405 | jaccard_sim = 0.0 406 | _weights[i] = 0.0 407 | else: 408 | jaccard_sim = intersection_size / union_size 409 | jaccard_similarities.append(jaccard_sim) 410 | similarity = ( 411 | torch.tensor(jaccard_similarities) * _weights 412 | ).sum() / _weights.sum() 413 | return similarity.item() 414 | 415 | # target_entities = [ 416 | # {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)} for e in target 417 | # ] 418 | # predicted_entities = [ 419 | # {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)} 420 | # for e in predicted_output 421 | # ] 422 | 423 | # assume a weight of 1 for each property 424 | weights = torch.ones_like(target_entities[0]["pk"]) 425 | 426 | all_similarities = [] 427 | for t, p in established_entity_matches_tensor: 428 | similarity = property_value_token_distance_acc(t, p, weights) 429 | all_similarities.append(similarity) 430 | 431 | if normalization == "Max": 432 | normalized_similarity = sum(all_similarities) / max( 433 | target_set_size, predicted_output_set_size 434 | ) 435 | elif normalization == "Precision": 436 | normalized_similarity = sum(all_similarities) / predicted_output_set_size 437 | elif normalization == "Recall": 438 | normalized_similarity = sum(all_similarities) / target_set_size 439 | 440 | per_property_acc_token = {key: 0.0 for key in keys} 441 | per_property_acc_aon = {key: 0.0 for key in keys} 442 | target_key_occurance = {key: 0.0 for key in keys} 443 | pred_key_occurance = {key: 0.0 for key in keys} 444 | key_matches = {key: 0.0 for key in keys} 445 | for e_target, e_pred in established_entity_matches: 446 | for key in keys: 447 | target_key_occurance[key] += key in e_target.keys() 448 | pred_key_occurance[key] += key in e_pred.keys() 449 | if key in e_target.keys() and key in e_pred.keys(): 450 | key_matches[key] += 1.0 451 | tokens_target = 
set(e_target[key].lower().split()) 452 | tokens_pred = set(e_pred[key].lower().split()) 453 | 454 | jaccard_sim = jaccard_similarity(tokens_target, tokens_pred) 455 | per_property_acc_token[key] += jaccard_sim 456 | 457 | per_property_acc_aon[key] += ( 458 | e_target[key].lower() == e_pred[key].lower() 459 | ) 460 | 461 | # # calculate per property similarity 462 | # prop_similarities = {} 463 | # 464 | # for pk in keys: 465 | # prop_similarities[pk] = {} 466 | # all_sim = [] 467 | # weights = torch.zeros_like(target_entities[0]["pk"]) 468 | # weights[get_key_index(pk, keys)] = 1 469 | # for (t, p) in established_entity_matches_tensor: 470 | # similarity = property_value_token_distance_acc(t, p, weights) 471 | # all_sim.append(similarity) 472 | # prop_similarities[pk]["Max"] = sum(all_sim) / max(target_set_size, predicted_output_set_size) 473 | # prop_similarities[pk]["Precision"] = sum(all_sim) / predicted_output_set_size 474 | # prop_similarities[pk]["Recall"] = sum(all_sim) / target_set_size 475 | # if str(prop_similarities[pk]["Max"]) == "nan": 476 | # prop_similarities[pk]["Max"] = 0 477 | # if str(prop_similarities[pk]["Precision"]) == "nan": 478 | # prop_similarities[pk]["Precision"] = 0 479 | # if str(prop_similarities[pk]["Recall"]) == "nan": 480 | # prop_similarities[pk]["Recall"] = 0 481 | 482 | counts_nested = { 483 | "per_property_acc_token": per_property_acc_token, 484 | "per_property_acc_aon": per_property_acc_aon, 485 | "target_key_occurance": target_key_occurance, 486 | "pred_key_occurance": pred_key_occurance, 487 | "key_matches": key_matches, 488 | } 489 | 490 | counts = { 491 | "established_entity_matches": len(established_entity_matches), 492 | "predicted_output_entities_no_dup": len(predicted_output), 493 | "target_entities_no_dup": len(target), 494 | } 495 | 496 | # bipartite_matching_metrics = { 497 | # "normalized_similarity": normalized_similarity, 498 | # # "pv_distances_token_loss": pv_distances_token_loss, 499 | # # "pv_distances_aon_loss": pv_distances_aon_loss, 500 | # # "pk_distances_acc_loss": pk_distances_acc_loss, 501 | # # "pv_distances_token_unweighted_loss": pv_distances_token_unweighted_loss, 502 | # # "pv_distances_aon_unweighted_loss": pv_distances_aon_unweighted_loss, 503 | # } 504 | final_metrics = { 505 | "normalized_similarity": normalized_similarity, 506 | } 507 | 508 | return final_metrics, counts, counts_nested 509 | 510 | 511 | if __name__ == "__main__": 512 | # Sample usage 513 | 514 | dummy_entity = { 515 | "t": torch.tensor([0.0, 0.0]), 516 | "pk": torch.tensor([0.0, 0.0, 0.0]), 517 | "pv": ["dummy", "dummy", "dummy"], 518 | } 519 | 520 | # In this demo, we assume N=3 (max entity), M=2 (num entity type), K=3 (num property keys) 521 | # t (ground-truth) is a one-hot vector. pk (ground-truth) is a multi-hot vector. 522 | # t (prediction) is post-softmax. pk (prediction) is post-sigmoid. 
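# A tiny worked example (illustrative sketch) of the three normalizations used
# by compute_bipartite_matching_metrics above: suppose two matches were
# established with token-level similarities 1.0 and 0.5, against 3 predicted
# entities and 2 ground-truth entities.
sims, n_pred, n_target = [1.0, 0.5], 3, 2
print(sum(sims) / max(n_target, n_pred))  # "Max"       -> 0.5
print(sum(sims) / n_pred)                 # "Precision" -> 0.5
print(sum(sims) / n_target)               # "Recall"    -> 0.75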
523 |     gt = [
524 |         {
525 |             "t": torch.tensor([1.0, 0.0]),
526 |             "pk": torch.tensor([1.0, 1.0, 1.0]),
527 |             "pv": ["XX apple", "round", "big"],
528 |         },
529 |         {
530 |             "t": torch.tensor([0.0, 1.0]),
531 |             "pk": torch.tensor([1.0, 0.0, 1.0]),
532 |             "pv": ["YY banana", "long", "big"],
533 |         },
534 |         {
535 |             "t": torch.tensor([1.0, 0.0]),
536 |             "pk": torch.tensor([0.0, 1.0, 1.0]),
537 |             "pv": ["ZZ grape", "round", "small"],
538 |         },
539 |     ]
540 |     pred1 = [
541 |         {
542 |             "t": torch.tensor([1.0, 0.0]),
543 |             "pk": torch.tensor([0.0, 1.0, 1.0]),
544 |             "pv": ["ZZ grape", "round", "small"],
545 |         },
546 |         {
547 |             "t": torch.tensor([1.0, 0.0]),
548 |             "pk": torch.tensor([1.0, 1.0, 1.0]),
549 |             "pv": ["XX apple", "round", "big"],
550 |         },
551 |         {
552 |             "t": torch.tensor([0.0, 1.0]),
553 |             "pk": torch.tensor([1.0, 0.0, 1.0]),
554 |             "pv": ["YY banana", "long", "big"],
555 |         },
556 |     ]
557 | 
558 |     pred2 = [
559 |         {
560 |             "t": torch.tensor([0.999, 0.001]),
561 |             "pk": torch.tensor([1.0, 1.0, 1.0]),
562 |             "pv": ["YY apple", "round", "big"],
563 |         },
564 |         {
565 |             "t": torch.tensor([0.001, 0.999]),
566 |             "pk": torch.tensor([1.0, 0.1, 1.0]),
567 |             "pv": ["XX banana", "long", "small"],
568 |         },
569 |         {
570 |             "t": torch.tensor([0.8, 0.2]),
571 |             "pk": torch.tensor([0.4, 0.9, 0.4]),
572 |             "pv": ["ZZ grape", "round", "small"],
573 |         },
574 |     ]
575 | 
576 |     pred3 = [
577 |         {
578 |             "t": torch.tensor([0.7, 0.3]),
579 |             "pk": torch.tensor([0.2, 0.8, 0.8]),
580 |             "pv": ["XX peach", "round", "very small"],
581 |         },
582 |         dummy_entity,
583 |         dummy_entity,
584 |     ]
585 | 
586 |     pred_list = [pred1, pred2, pred3]
587 |     for i in range(len(pred_list)):
588 |         print(f"Compare GT with Pred{i + 1}:")
589 |         # Only the string measures "ExactName", "ApproxName" and "MultiProp"
590 |         # are supported by compute_distances; weights_pks puts weight 1 on the
591 |         # name slot (index 0 of "pv") and 0 on the remaining property slots.
592 |         distances = compute_distances(
593 |             gt, pred_list[i], measures="ExactName", weights_pks=torch.tensor([1.0, 0.0, 0.0])
594 |         )
595 | 
596 |         (
597 |             permutation_ground_truth,
598 |             permutation_prediction,
599 |         ) = bipartite_matching(distances)
600 |         print("permutation_ground_truth:", permutation_ground_truth)
601 |         print("permutation_prediction:", permutation_prediction)
602 |         print("-----")
603 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/model/__init__.py
--------------------------------------------------------------------------------
/model/t5_with_t5decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import T5Config, T5ForConditionalGeneration
4 | 
5 | 
6 | class T5_with_T5Decoder(nn.Module):
7 |     def __init__(self, pretrained_t5_name, tokenizer, pre_train=True):
8 |         super().__init__()
9 | 
10 |         if pre_train:
11 |             print("Using Pre-trained T5 model")
12 |             self.t5_model = T5ForConditionalGeneration.from_pretrained(
13 |                 pretrained_t5_name
14 |             )
15 |         else:
16 |             print("Using Randomly Initialized T5 model")
17 |             self.t5_model = 
T5ForConditionalGeneration(T5Config()) 18 | print("Window Size:", self.t5_model.config.d_model) 19 | self.tokenizer = tokenizer 20 | self.d_model = self.t5_model.config.d_model 21 | 22 | def forward(self, input_ids=None, attention_mask=None): 23 | 24 | original_text_embeddings = self.t5_model.shared(input_ids) 25 | # Encode the input text 26 | encoder_outputs = self.t5_model.encoder( 27 | input_ids=input_ids, 28 | attention_mask=attention_mask, 29 | ) 30 | 31 | return original_text_embeddings, encoder_outputs 32 | 33 | def decode_at(self, encoder_outputs, decoder_input_ids, target_sequence_length): 34 | all_logits = [] 35 | past_key_values = None 36 | 37 | for i in range(target_sequence_length): 38 | outputs = self.t5_model( 39 | input_ids=None, 40 | attention_mask=None, 41 | decoder_input_ids=decoder_input_ids, 42 | encoder_outputs=encoder_outputs, 43 | past_key_values=past_key_values, 44 | use_cache=True, 45 | return_dict=True, 46 | ) 47 | 48 | next_token_logits = outputs.logits[:, -1] 49 | next_tokens = next_token_logits.argmax(-1, keepdim=True) 50 | all_logits.append(next_token_logits.unsqueeze(1)) 51 | 52 | # Update decoder_input_ids for the next iteration 53 | decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1) 54 | past_key_values = ( 55 | outputs.past_key_values 56 | ) # Store past key values for the next iteration 57 | 58 | return torch.cat(all_logits, dim=1) 59 | -------------------------------------------------------------------------------- /model/t5_with_t5decoder_emb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import T5Config, T5ForConditionalGeneration 4 | 5 | 6 | class T5_with_T5Decoder(nn.Module): 7 | def __init__(self, pretrained_t5_name, tokenizer, pre_train=True): 8 | super().__init__() 9 | 10 | if pre_train: 11 | print("Using Pre-trained T5 model") 12 | self.t5_model = T5ForConditionalGeneration.from_pretrained( 13 | pretrained_t5_name 14 | ) 15 | else: 16 | print("Using Randomly Initialized T5 model") 17 | self.t5_model = T5ForConditionalGeneration(T5Config()) 18 | print("Window Size:", self.t5_model.config.d_model) 19 | self.tokenizer = tokenizer 20 | self.d_model = self.t5_model.config.d_model 21 | 22 | def forward( 23 | self, 24 | input_ids=None, 25 | attention_mask=None, 26 | decoder_input_ids=None, 27 | decoder_attention_mask=None, 28 | output_hidden_states=False, 29 | return_dict=True, 30 | ): 31 | 32 | original_text_embeddings = self.t5_model.shared(input_ids) 33 | # Encode the input text 34 | encoder_outputs = self.t5_model.encoder( 35 | input_ids=input_ids, 36 | attention_mask=attention_mask, 37 | # return_dict=True, 38 | # output_hidden_states=output_hidden_states 39 | ) 40 | 41 | return original_text_embeddings, encoder_outputs 42 | 43 | def decode_at_emb(self, encoder_outputs, target_sequence_length=20): 44 | device = encoder_outputs.last_hidden_state.device 45 | num_position = encoder_outputs.last_hidden_state.size(0) 46 | 47 | # Initialize start tokens 48 | start_token_id = ( 49 | self.tokenizer.bos_token_id 50 | if self.tokenizer.bos_token_id is not None 51 | else self.tokenizer.pad_token_id 52 | ) 53 | start_tokens = torch.full( 54 | (num_position, 1), 55 | start_token_id, 56 | dtype=torch.long, 57 | device=device, 58 | ) 59 | 60 | # Get the embeddings of the start tokens 61 | decoder_inputs_embeds = self.t5_model.get_input_embeddings()(start_tokens) 62 | 63 | # Initialize the decoder's attention mask with a single 1 for the 
start token 64 | decoder_attention_mask = torch.ones( 65 | (num_position, 1), dtype=torch.long, device=device 66 | ) 67 | 68 | past_key_values = None 69 | return_embeds = None 70 | 71 | for _ in range(target_sequence_length): 72 | outputs = self.t5_model( 73 | encoder_outputs=encoder_outputs, 74 | decoder_inputs_embeds=decoder_inputs_embeds, 75 | decoder_attention_mask=decoder_attention_mask, 76 | past_key_values=past_key_values, 77 | use_cache=True, 78 | output_hidden_states=True, 79 | ) 80 | 81 | # Extract the last hidden state (embedding) of the last token 82 | last_embedding = outputs.decoder_hidden_states[-1][ 83 | :, -1:, : 84 | ] # (b, 1, d_model) 85 | return_embeds = ( 86 | last_embedding 87 | if return_embeds is None 88 | else torch.cat([return_embeds, last_embedding], dim=1) 89 | ) 90 | 91 | # if use past_key_values, only need to input the last decoder_inputs_embs 92 | decoder_inputs_embeds = last_embedding 93 | 94 | # Update the decoder's attention mask 95 | decoder_attention_mask = torch.cat( 96 | [ 97 | decoder_attention_mask, 98 | torch.ones((num_position, 1), dtype=torch.long, device=device), 99 | ], 100 | dim=1, 101 | ) 102 | 103 | # Store past key values for the next iteration 104 | past_key_values = outputs.past_key_values 105 | 106 | return return_embeds 107 | -------------------------------------------------------------------------------- /model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/model/utils/__init__.py -------------------------------------------------------------------------------- /model/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | import torch as th 11 | import torch.distributed as dist 12 | 13 | # Change this to reflect your cluster layout. 14 | 15 | 16 | def setup_dist(): 17 | """ 18 | Setup a distributed process group. 19 | """ 20 | if dist.is_initialized(): 21 | return 22 | 23 | backend = "gloo" if not th.cuda.is_available() else "nccl" 24 | 25 | if backend == "gloo": 26 | hostname = "localhost" 27 | else: 28 | hostname = socket.gethostbyname(socket.getfqdn()) 29 | 30 | if os.environ.get("LOCAL_RANK") is None: 31 | os.environ["MASTER_ADDR"] = hostname 32 | os.environ["RANK"] = str(0) 33 | os.environ["WORLD_SIZE"] = str(1) 34 | port = _find_free_port() 35 | os.environ["MASTER_PORT"] = str(port) 36 | os.environ["LOCAL_RANK"] = str(0) 37 | 38 | dist.init_process_group(backend=backend, init_method="env://") 39 | 40 | if th.cuda.is_available(): # This clears remaining caches in GPU 0 41 | th.cuda.set_device(dev()) 42 | th.cuda.empty_cache() 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda:{os.environ['LOCAL_RANK']}") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file. 57 | """ 58 | # if int(os.environ['LOCAL_RANK']) == 0: 59 | with bf.BlobFile(path, "rb") as f: 60 | data = f.read() 61 | return th.load(io.BytesIO(data), **kwargs) 62 | 63 | 64 | def sync_params(params): 65 | """ 66 | Synchronize a sequence of Tensors across ranks from rank 0. 
67 | """ 68 | for p in params: 69 | with th.no_grad(): 70 | dist.broadcast(p, 0) 71 | 72 | 73 | def _find_free_port(): 74 | try: 75 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 76 | s.bind(("127.0.0.1", 0)) # Bind to the local interface only 77 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 78 | return s.getsockname()[1] 79 | finally: 80 | s.close() 81 | -------------------------------------------------------------------------------- /model/utils/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import torch.nn as nn 6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 7 | 8 | 9 | def convert_module_to_f16(l): 10 | """ 11 | Convert primitive modules to float16. 12 | """ 13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 14 | l.weight.data = l.weight.data.half() 15 | l.bias.data = l.bias.data.half() 16 | 17 | 18 | def convert_module_to_f32(l): 19 | """ 20 | Convert primitive modules to float32, undoing convert_module_to_f16(). 21 | """ 22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 23 | l.weight.data = l.weight.data.float() 24 | l.bias.data = l.bias.data.float() 25 | 26 | 27 | def make_master_params(model_params): 28 | """ 29 | Copy model parameters into a (differently-shaped) list of full-precision 30 | parameters. 31 | """ 32 | master_params = _flatten_dense_tensors( 33 | [param.detach().float() for param in model_params] 34 | ) 35 | master_params = nn.Parameter(master_params) 36 | master_params.requires_grad = True 37 | return [master_params] 38 | 39 | 40 | def model_grads_to_master_grads(model_params, master_params): 41 | """ 42 | Copy the gradients from the model parameters into the master parameters 43 | from make_master_params(). 44 | """ 45 | master_params[0].grad = _flatten_dense_tensors( 46 | [param.grad.data.detach().float() for param in model_params] 47 | ) 48 | 49 | 50 | def master_params_to_model_params(model_params, master_params): 51 | """ 52 | Copy the master parameter data back into the model parameters. 53 | """ 54 | # Without copying to a list, if a generator is passed, this will 55 | # silently not copy any parameters. 56 | model_params = list(model_params) 57 | 58 | for param, master_param in zip( 59 | model_params, unflatten_master_params(model_params, master_params) 60 | ): 61 | param.detach().copy_(master_param) 62 | 63 | 64 | def unflatten_master_params(model_params, master_params): 65 | """ 66 | Unflatten the master parameters to look like model_params. 
67 | """ 68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 69 | 70 | 71 | def zero_grad(model_params): 72 | for param in model_params: 73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/ 74 | # optim/optimizer.html#Optimizer.add_param_group 75 | if param.grad is not None: 76 | param.grad.detach_() 77 | param.grad.zero_() 78 | -------------------------------------------------------------------------------- /model/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import datetime 7 | import json 8 | import os 9 | import os.path as osp 10 | # import shutil 11 | import sys 12 | import tempfile 13 | import time 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | import wandb 19 | 20 | DEBUG = 10 21 | INFO = 20 22 | WARN = 30 23 | ERROR = 40 24 | 25 | DISABLED = 50 26 | 27 | 28 | class KVWriter(object): 29 | def writekvs(self, kvs): 30 | raise NotImplementedError 31 | 32 | 33 | class SeqWriter(object): 34 | def writeseq(self, seq): 35 | raise NotImplementedError 36 | 37 | 38 | class HumanOutputFormat(KVWriter, SeqWriter): 39 | def __init__(self, filename_or_file): 40 | if isinstance(filename_or_file, str): 41 | self.file = open(filename_or_file, "wt") 42 | self.own_file = True 43 | else: 44 | assert hasattr(filename_or_file, "read"), ( 45 | "expected file or str, got %s" % filename_or_file 46 | ) 47 | self.file = filename_or_file 48 | self.own_file = False 49 | 50 | def writekvs(self, kvs): 51 | # Create strings for printing 52 | key2str = {} 53 | for key, val in sorted(kvs.items()): 54 | if hasattr(val, "__float__"): 55 | valstr = "%-8.3g" % val 56 | else: 57 | valstr = str(val) 58 | key2str[self._truncate(key)] = self._truncate(valstr) 59 | 60 | # Find max widths 61 | if len(key2str) == 0: 62 | print("WARNING: tried to write empty key-value dict") 63 | return 64 | else: 65 | keywidth = max(map(len, key2str.keys())) 66 | valwidth = max(map(len, key2str.values())) 67 | 68 | # Write out the data 69 | dashes = "-" * (keywidth + valwidth + 7) 70 | lines = [dashes] 71 | for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 72 | lines.append( 73 | "| %s%s | %s%s |" 74 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 75 | ) 76 | lines.append(dashes) 77 | self.file.write("\n".join(lines) + "\n") 78 | 79 | # Flush the output to the file 80 | self.file.flush() 81 | 82 | def _truncate(self, s): 83 | maxlen = 30 84 | return s[: maxlen - 3] + "..." 
if len(s) > maxlen else s 85 | 86 | def writeseq(self, seq): 87 | seq = list(seq) 88 | for i, elem in enumerate(seq): 89 | self.file.write(elem) 90 | if i < len(seq) - 1: # add space unless this is the last one 91 | self.file.write(" ") 92 | self.file.write("\n") 93 | self.file.flush() 94 | 95 | def close(self): 96 | if self.own_file: 97 | self.file.close() 98 | 99 | 100 | class JSONOutputFormat(KVWriter): 101 | def __init__(self, filename): 102 | self.file = open(filename, "wt") 103 | 104 | def writekvs(self, kvs): 105 | for k, v in sorted(kvs.items()): 106 | if hasattr(v, "dtype"): 107 | kvs[k] = float(v) 108 | self.file.write(json.dumps(kvs) + "\n") 109 | self.file.flush() 110 | 111 | def close(self): 112 | self.file.close() 113 | 114 | 115 | class CSVOutputFormat(KVWriter): 116 | def __init__(self, filename): 117 | self.file = open(filename, "w+t") 118 | self.keys = [] 119 | self.sep = "," 120 | 121 | def writekvs(self, kvs): 122 | # Add our current row to the history 123 | extra_keys = list(kvs.keys() - self.keys) 124 | extra_keys.sort() 125 | if extra_keys: 126 | self.keys.extend(extra_keys) 127 | self.file.seek(0) 128 | lines = self.file.readlines() 129 | self.file.seek(0) 130 | for i, k in enumerate(self.keys): 131 | if i > 0: 132 | self.file.write(",") 133 | self.file.write(k) 134 | self.file.write("\n") 135 | for line in lines[1:]: 136 | self.file.write(line[:-1]) 137 | self.file.write(self.sep * len(extra_keys)) 138 | self.file.write("\n") 139 | for i, k in enumerate(self.keys): 140 | if i > 0: 141 | self.file.write(",") 142 | v = kvs.get(k) 143 | if v is not None: 144 | self.file.write(str(v)) 145 | self.file.write("\n") 146 | self.file.flush() 147 | 148 | def close(self): 149 | self.file.close() 150 | 151 | 152 | class TensorBoardOutputFormat(KVWriter): 153 | """ 154 | Dumps key/value pairs into TensorBoard's numeric format. 155 | """ 156 | 157 | def __init__(self, dir): 158 | os.makedirs(dir, exist_ok=True) 159 | self.dir = dir 160 | self.step = 1 161 | prefix = "events" 162 | path = osp.join(osp.abspath(dir), prefix) 163 | import tensorflow as tf 164 | from tensorflow.core.util import event_pb2 165 | from tensorflow.python import pywrap_tensorflow 166 | from tensorflow.python.util import compat 167 | 168 | self.tf = tf 169 | self.event_pb2 = event_pb2 170 | self.pywrap_tensorflow = pywrap_tensorflow 171 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 172 | 173 | def writekvs(self, kvs): 174 | def summary_val(k, v): 175 | kwargs = {"tag": k, "simple_value": float(v)} 176 | return self.tf.Summary.Value(**kwargs) 177 | 178 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 179 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 180 | event.step = ( 181 | self.step 182 | ) # is there any reason why you'd want to specify the step? 
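# Note: self.step here is the logger's own counter (initialized to 1 and
# incremented once per writekvs() call below), so each TensorBoard event's
# step marks one dumpkvs() flush rather than one training iteration.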
183 | self.writer.WriteEvent(event) 184 | self.writer.Flush() 185 | self.step += 1 186 | 187 | def close(self): 188 | if self.writer: 189 | self.writer.Close() 190 | self.writer = None 191 | 192 | 193 | def make_output_format(format, ev_dir, log_suffix=""): 194 | os.makedirs(ev_dir, exist_ok=True) 195 | if format == "stdout": 196 | return HumanOutputFormat(sys.stdout) 197 | elif format == "log": 198 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 199 | elif format == "json": 200 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 201 | elif format == "csv": 202 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 203 | elif format == "tensorboard": 204 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 205 | else: 206 | raise ValueError("Unknown format specified: %s" % (format,)) 207 | 208 | 209 | # ================================================================ 210 | # API 211 | # ================================================================ 212 | 213 | 214 | def logkv(key, val): 215 | """ 216 | Log a value of some diagnostic 217 | Call this once for each diagnostic quantity, each iteration 218 | If called many times, last value will be used. 219 | """ 220 | get_current().logkv(key, val) 221 | 222 | 223 | def logkv_mean(key, val): 224 | """ 225 | The same as logkv(), but if called many times, values averaged. 226 | """ 227 | get_current().logkv_mean(key, val) 228 | 229 | 230 | def logkvs(d): 231 | """ 232 | Log a dictionary of key-value pairs 233 | """ 234 | for k, v in d.items(): 235 | logkv(k, v) 236 | 237 | 238 | def dumpkvs(): 239 | """ 240 | Write all of the diagnostics from the current iteration 241 | """ 242 | return get_current().dumpkvs() 243 | 244 | 245 | def getkvs(): 246 | return get_current().name2val 247 | 248 | 249 | def log(*args, level=INFO): 250 | """ 251 | Write the sequence of args, with no separators, 252 | to the console and output files (if you've configured an output file). 253 | """ 254 | get_current().log(*args, level=level) 255 | 256 | 257 | def debug(*args): 258 | log(*args, level=DEBUG) 259 | 260 | 261 | def info(*args): 262 | log(*args, level=INFO) 263 | 264 | 265 | def warn(*args): 266 | log(*args, level=WARN) 267 | 268 | 269 | def error(*args): 270 | log(*args, level=ERROR) 271 | 272 | 273 | def set_level(level): 274 | """ 275 | Set logging threshold on current logger. 276 | """ 277 | get_current().set_level(level) 278 | 279 | 280 | def set_comm(comm): 281 | get_current().set_comm(comm) 282 | 283 | 284 | def get_dir(): 285 | """ 286 | Get directory that log files are being written to. 
287 | will be None if there is no output directory (i.e., if you didn't call start) 288 | """ 289 | return get_current().get_dir() 290 | 291 | 292 | record_tabular = logkv 293 | dump_tabular = dumpkvs 294 | 295 | 296 | @contextmanager 297 | def profile_kv(scopename): 298 | logkey = "wait_" + scopename 299 | tstart = time.time() 300 | try: 301 | yield 302 | finally: 303 | get_current().name2val[logkey] += time.time() - tstart 304 | 305 | 306 | def profile(n): 307 | """ 308 | Usage: 309 | @profile("my_func") 310 | def my_func(): code 311 | """ 312 | 313 | def decorator_with_name(func): 314 | def func_wrapper(*args, **kwargs): 315 | with profile_kv(n): 316 | return func(*args, **kwargs) 317 | 318 | return func_wrapper 319 | 320 | return decorator_with_name 321 | 322 | 323 | # ================================================================ 324 | # Backend 325 | # ================================================================ 326 | 327 | 328 | def get_current(): 329 | if Logger.CURRENT is None: 330 | _configure_default_logger() 331 | 332 | return Logger.CURRENT 333 | 334 | 335 | class Logger(object): 336 | DEFAULT = None # A logger with no output files. (See right below class definition) 337 | # So that you can still log to the terminal without setting up any output files 338 | CURRENT = None # Current logger being used by the free functions above 339 | 340 | def __init__(self, dir, output_formats, comm=None): 341 | self.name2val = defaultdict(float) # values this iteration 342 | self.name2cnt = defaultdict(int) 343 | self.level = INFO 344 | self.dir = dir 345 | self.output_formats = output_formats 346 | self.comm = comm 347 | 348 | # Logging API, forwarded 349 | # ---------------------------------------- 350 | def logkv(self, key, val): 351 | self.name2val[key] = val 352 | 353 | def logkv_mean(self, key, val): 354 | oldval, cnt = self.name2val[key], self.name2cnt[key] 355 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 356 | self.name2cnt[key] = cnt + 1 357 | 358 | def dumpkvs(self, prefix=None): 359 | if self.comm is None: 360 | d = self.name2val 361 | else: 362 | d = mpi_weighted_mean( 363 | self.comm, 364 | { 365 | name: (val, self.name2cnt.get(name, 1)) 366 | for (name, val) in self.name2val.items() 367 | }, 368 | ) 369 | if self.comm.rank != 0: 370 | d["dummy"] = 1 # so we don't get a warning about empty dict 371 | # LISA 372 | out = d.copy() # Return the dict for unit testing purposes 373 | if int(os.environ["LOCAL_RANK"]) == 0: 374 | wandb.log({**d}) 375 | for fmt in self.output_formats: 376 | if isinstance(fmt, KVWriter): 377 | fmt.writekvs(d) 378 | self.name2val.clear() 379 | self.name2cnt.clear() 380 | return out 381 | 382 | def log(self, *args, level=INFO): 383 | if self.level <= level: 384 | self._do_log(args) 385 | 386 | # Configuration 387 | # ---------------------------------------- 388 | def set_level(self, level): 389 | self.level = level 390 | 391 | def set_comm(self, comm): 392 | self.comm = comm 393 | 394 | def get_dir(self): 395 | return self.dir 396 | 397 | def close(self): 398 | for fmt in self.output_formats: 399 | fmt.close() 400 | 401 | # Misc 402 | # ---------------------------------------- 403 | def _do_log(self, args): 404 | for fmt in self.output_formats: 405 | if isinstance(fmt, SeqWriter): 406 | fmt.writeseq(map(str, args)) 407 | 408 | 409 | def get_rank_without_mpi_import(): 410 | # check environment variables here instead of importing mpi4py 411 | # to avoid calling MPI_Init() when this module is imported 412 | for varname in ["PMI_RANK", 
"OMPI_COMM_WORLD_RANK"]: 413 | if varname in os.environ: 414 | return int(os.environ[varname]) 415 | return 0 416 | 417 | 418 | def mpi_weighted_mean(comm, local_name2valcount): 419 | """ 420 | Copied from: https://github.com/openai/baselines/blob/ 421 | ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 422 | Perform a weighted average over dicts that are each on a different node 423 | Input: local_name2valcount: dict mapping key -> (value, count) 424 | Returns: key -> mean 425 | """ 426 | all_name2valcount = comm.gather(local_name2valcount) 427 | if comm.rank == 0: 428 | name2sum = defaultdict(float) 429 | name2count = defaultdict(float) 430 | for n2vc in all_name2valcount: 431 | for name, (val, count) in n2vc.items(): 432 | try: 433 | val = float(val) 434 | except ValueError: 435 | if comm.rank == 0: 436 | warnings.warn( 437 | "WARNING: tried to compute mean on non-float {}={}".format( 438 | name, val 439 | ) 440 | ) 441 | else: 442 | name2sum[name] += val * count 443 | name2count[name] += count 444 | return {name: name2sum[name] / name2count[name] for name in name2sum} 445 | else: 446 | return {} 447 | 448 | 449 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 450 | """ 451 | If comm is provided, average all numerical stats across that comm 452 | """ 453 | if dir is None: 454 | dir = os.getenv("OPENAI_LOGDIR") 455 | if dir is None: 456 | dir = osp.join( 457 | tempfile.gettempdir(), 458 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 459 | ) 460 | assert isinstance(dir, str) 461 | dir = os.path.expanduser(dir) 462 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 463 | 464 | rank = get_rank_without_mpi_import() 465 | if rank > 0: 466 | log_suffix = log_suffix + "-rank%03i" % rank 467 | 468 | if format_strs is None: 469 | if rank == 0: 470 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 471 | else: 472 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 473 | format_strs = filter(None, format_strs) 474 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 475 | 476 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 477 | if output_formats: 478 | log("Logging to %s" % dir) 479 | 480 | 481 | def _configure_default_logger(): 482 | configure() 483 | Logger.DEFAULT = Logger.CURRENT 484 | 485 | 486 | def reset(): 487 | if Logger.CURRENT is not Logger.DEFAULT: 488 | Logger.CURRENT.close() 489 | Logger.CURRENT = Logger.DEFAULT 490 | log("Reset logger") 491 | 492 | 493 | @contextmanager 494 | def scoped_configure(dir=None, format_strs=None, comm=None): 495 | prevlogger = Logger.CURRENT 496 | configure(dir=dir, format_strs=format_strs, comm=comm) 497 | try: 498 | yield 499 | finally: 500 | Logger.CURRENT.close() 501 | Logger.CURRENT = prevlogger 502 | -------------------------------------------------------------------------------- /model/utils/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | import torch as th 9 | 10 | 11 | def normal_kl(mean1, logvar1, mean2, logvar2): 12 | """ 13 | Compute the KL divergence between two gaussians. 
14 | 15 | Shapes are automatically broadcasted, so batches can be compared to 16 | scalars, among other use cases. 17 | """ 18 | tensor = None 19 | for obj in (mean1, logvar1, mean2, logvar2): 20 | if isinstance(obj, th.Tensor): 21 | tensor = obj 22 | break 23 | assert tensor is not None, "at least one argument must be a Tensor" 24 | 25 | # Force variances to be Tensors. Broadcasting helps convert scalars to 26 | # Tensors, but it does not work for th.exp(). 27 | logvar1, logvar2 = [ 28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 29 | for x in (logvar1, logvar2) 30 | ] 31 | 32 | # print(logvar2.shape) 33 | # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2)) 34 | # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()}, 35 | # mse={((mean1 - mean2) ** 2).mean().item()}') 36 | 37 | return 0.5 * ( 38 | -1.0 39 | + logvar2 40 | - logvar1 41 | + th.exp(logvar1 - logvar2) 42 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 43 | ) 44 | 45 | 46 | def approx_standard_normal_cdf(x): 47 | """ 48 | A fast approximation of the cumulative distribution function of the 49 | standard normal. 50 | """ 51 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 52 | 53 | 54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 55 | """ 56 | Compute the log-likelihood of a Gaussian distribution discretizing to a 57 | given image. 58 | 59 | :param x: the target images. It is assumed that this was uint8 values, 60 | rescaled to the range [-1, 1]. 61 | :param means: the Gaussian mean Tensor. 62 | :param log_scales: the Gaussian log stddev Tensor. 63 | :return: a tensor like x of log probabilities (in nats). 64 | """ 65 | assert x.shape == means.shape == log_scales.shape 66 | centered_x = x - means 67 | inv_stdv = th.exp(-log_scales) 68 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 69 | cdf_plus = approx_standard_normal_cdf(plus_in) 70 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 71 | cdf_min = approx_standard_normal_cdf(min_in) 72 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 73 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 74 | cdf_delta = cdf_plus - cdf_min 75 | log_probs = th.where( 76 | x < -0.999, 77 | log_cdf_plus, 78 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 79 | ) 80 | assert log_probs.shape == x.shape 81 | return log_probs 82 | 83 | 84 | def gaussian_density(x, *, means, log_scales): 85 | from torch.distributions import Normal 86 | 87 | normal_dist = Normal(means, log_scales.exp()) 88 | logp = normal_dist.log_prob(x) 89 | return logp 90 | 91 | 92 | def discretized_text_log_likelihood(x, *, means, log_scales): 93 | """ 94 | Compute the log-likelihood of a Gaussian distribution discretizing to a 95 | given image. 96 | 97 | :param x: the target images. It is assumed that this was uint8 values, 98 | rescaled to the range [-1, 1]. 99 | :param means: the Gaussian mean Tensor. 100 | :param log_scales: the Gaussian log stddev Tensor. 101 | :return: a tensor like x of log probabilities (in nats). 
102 | """ 103 | print(x.shape, means.shape) 104 | # assert x.shape == means.shape == log_scales.shape 105 | print(x, means) 106 | centered_x = x - means 107 | inv_stdv = th.exp(-log_scales) 108 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 109 | cdf_plus = approx_standard_normal_cdf(plus_in) 110 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 111 | cdf_min = approx_standard_normal_cdf(min_in) 112 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 113 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 114 | cdf_delta = cdf_plus - cdf_min 115 | log_probs = th.where( 116 | x < -0.999, 117 | log_cdf_plus, 118 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 119 | ) 120 | assert log_probs.shape == x.shape 121 | return log_probs 122 | -------------------------------------------------------------------------------- /model/utils/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def linear(*args, **kwargs): 23 | """ 24 | Create a linear module. 25 | """ 26 | return nn.Linear(*args, **kwargs) 27 | 28 | 29 | def avg_pool_nd(dims, *args, **kwargs): 30 | """ 31 | Create a 1D, 2D, or 3D average pooling module. 32 | """ 33 | if dims == 1: 34 | return nn.AvgPool1d(*args, **kwargs) 35 | elif dims == 2: 36 | return nn.AvgPool2d(*args, **kwargs) 37 | elif dims == 3: 38 | return nn.AvgPool3d(*args, **kwargs) 39 | raise ValueError(f"unsupported dimensions: {dims}") 40 | 41 | 42 | def update_ema(target_params, source_params, rate=0.99): 43 | """ 44 | Update target parameters to be closer to those of source parameters using 45 | an exponential moving average. 46 | 47 | :param target_params: the target parameter sequence. 48 | :param source_params: the source parameter sequence. 49 | :param rate: the EMA rate (closer to 1 means slower). 50 | """ 51 | for targ, src in zip(target_params, source_params): 52 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 53 | 54 | 55 | def zero_module(module): 56 | """ 57 | Zero out the parameters of a module and return it. 58 | """ 59 | for p in module.parameters(): 60 | p.detach().zero_() 61 | return module 62 | 63 | 64 | def scale_module(module, scale): 65 | """ 66 | Scale the parameters of a module and return it. 67 | """ 68 | for p in module.parameters(): 69 | p.detach().mul_(scale) 70 | return module 71 | 72 | 73 | def mean_flat(tensor): 74 | """ 75 | Take the mean over all non-batch dimensions. 76 | """ 77 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 78 | 79 | 80 | def normalization(channels): 81 | """ 82 | Make a standard normalization layer. 83 | 84 | :param channels: number of input channels. 85 | :return: an nn.Module for normalization. 86 | """ 87 | return GroupNorm32(32, channels) 88 | 89 | 90 | def timestep_embedding(timesteps, dim, max_period=10000): 91 | """ 92 | Create sinusoidal timestep embeddings. 93 | 94 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 95 | These may be fractional. 96 | :param dim: the dimension of the output. 97 | :param max_period: controls the minimum frequency of the embeddings. 
98 |     :return: an [N x dim] Tensor of positional embeddings.
99 |     """
100 |     half = dim // 2
101 |     freqs = th.exp(
102 |         -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
103 |     ).to(device=timesteps.device)
104 |     args = timesteps[:, None].float() * freqs[None]
105 |     embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
106 |     if dim % 2:
107 |         embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
108 |     return embedding
109 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | beautifulsoup4==4.12.3
3 | black==24.4.2
4 | flake8==7.0.0
5 | huggingface-hub==0.27.1
6 | isort==5.13.2
7 | matplotlib==3.9.0
8 | numpy==1.26.4
9 | pandas==2.2.3
10 | peft==0.11.1
11 | regex==2024.11.6
12 | requests==2.32.2
13 | sentencepiece==0.2.0
14 | tokenizers==0.20.2
15 | torch==2.5.1
16 | tqdm==4.66.4
17 | transformers @ git+https://github.com/huggingface/transformers@bdb9106f247fca48a71eb384be25dbbd29b065a8
18 | triton==3.1.0
19 | urllib3==2.2.2
20 | wandb==0.17.0
21 | Wikidata==0.7.0
22 | wikipedia==1.4.0
23 | xformers==0.0.26.post1
24 | 
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | 
3 | 
4 | class Trainer(ABC):
5 |     @abstractmethod
6 |     def train(
7 |         self,
8 |         save_path,
9 |         model,
10 |         train_dataloader,
11 |         val_dataloader,
12 |         optimizer,
13 |         scheduler,
14 |         epochs,
15 |         **kwargs
16 |     ):
17 |         # training process
18 |         raise NotImplementedError
19 | 
20 |     @abstractmethod
21 |     def evaluate(self, model, test_dataloader, tokenizer, **kwargs):
22 |         # testing process
23 |         raise NotImplementedError
24 | 
--------------------------------------------------------------------------------
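As a point of reference, here is a minimal sketch of how a concrete subclass is expected to satisfy the Trainer contract above. ToyTrainer, its loss access (`model(**batch).loss`, a Hugging Face-style output), and the loop body are illustrative assumptions, not part of the repository; the real implementation (trainer_musee.py, next) adds multi-step losses, early stopping, and checkpointing.

import torch

from trainer import Trainer  # assumes the repository root is on sys.path


class ToyTrainer(Trainer):
    def train(self, save_path, model, train_dataloader, val_dataloader,
              optimizer, scheduler, epochs, **kwargs):
        # Minimal supervised loop over the abstract signature.
        for _ in range(epochs):
            model.train()
            for batch in train_dataloader:
                optimizer.zero_grad()
                loss = model(**batch).loss  # assumption: HF-style model output
                loss.backward()
                optimizer.step()
                scheduler.step()
        torch.save(model.state_dict(), save_path)

    def evaluate(self, model, test_dataloader, tokenizer, **kwargs):
        # Average loss over the test set; real metrics live in compute_metrics.py.
        model.eval()
        with torch.no_grad():
            losses = [model(**batch).loss.item() for batch in test_dataloader]
        return sum(losses) / max(len(losses), 1)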
/trainer/trainer_musee.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | import wandb
7 | from torch import nn
8 | from torch.nn.functional import softmax
9 | from transformers import T5Config, T5ForConditionalGeneration
10 | 
11 | from . import Trainer
12 | 
13 | 
14 | def decode_logits(logits, tokenizer):
15 |     probs = softmax(logits, dim=-1)
16 | 
17 |     # Get the most likely token IDs
18 |     predicted_ids = torch.argmax(probs, dim=-1)
19 |     print("predicted_ids:", predicted_ids)
20 | 
21 |     # Decode the token IDs to tokens
22 |     decoded_tokens = []
23 |     for i in range(predicted_ids.shape[0]):
24 |         decoded_sequence = tokenizer.decode(predicted_ids[i], skip_special_tokens=True)
25 |         decoded_tokens.append(decoded_sequence.split())
26 | 
27 |     return decoded_tokens
28 | 
29 | 
30 | def replace_with_closest_embedding(
31 |     predict_pk_ids, added_ent_type_tokens, added_pk_tokens, model, device
32 | ):
33 |     # Getting the embeddings from the model and moving them to the specified device
34 |     embeddings = model.t5_model.shared.weight.to(device)
35 | 
36 |     def find_closest_token(target_token, embeddings, allowed_tokens, device):
37 |         target_embedding = (
38 |             embeddings[target_token].unsqueeze(0).to(device)
39 |         )  # Get the embedding of the target token
40 |         allowed_embeddings = embeddings[allowed_tokens].to(
41 |             device
42 |         )  # Get embeddings of allowed tokens
43 |         distances = torch.norm(target_embedding - allowed_embeddings, dim=1)
44 |         closest_idx = distances.argmin()
45 |         closest_token = allowed_tokens[closest_idx]
46 | 
47 |         return closest_token
48 | 
49 |     for idx, sequence in enumerate(predict_pk_ids):
50 |         for token_idx, token in enumerate(sequence):
51 |             if token_idx == 0:
52 |                 allowed_tokens = added_ent_type_tokens
53 |             else:
54 |                 allowed_tokens = added_pk_tokens + [1]  # Adding end token
55 | 
56 |             closest_token = find_closest_token(
57 |                 token, embeddings, allowed_tokens, device
58 |             )
59 |             predict_pk_ids[idx, token_idx] = closest_token
60 | 
61 |     return predict_pk_ids
62 | 
63 | 
64 | class Predictor_E_Pk_Pv(nn.Module):
65 |     def __init__(
66 |         self,
67 |         pretrained_model_name,
68 |         max_seq_length,
69 |         max_num_entity,
70 |         tokenizer,
71 |     ):
72 |         super(Predictor_E_Pk_Pv, self).__init__()
73 | 
74 |         self.tokenizer = tokenizer
75 |         self.max_seq_length = max_seq_length
76 |         self.max_num_entity = max_num_entity
77 | 
78 |         self.t5_model = T5ForConditionalGeneration.from_pretrained(
79 |             pretrained_model_name
80 |         )
81 | 
82 |     def _prompt_decoder_return_logits(
83 |         self,
84 |         prompt_tokenized,
85 |         input_ids,
86 |         encoder_outputs,
87 |         attention_mask,
88 |         labels,
89 |         added_ent_type_tokens=None,
90 |         added_pk_tokens=None,
91 |     ):
92 |         start_token_id = self.tokenizer.pad_token_id
93 |         end_token_id = self.tokenizer.eos_token_id
94 |         start_tokens = torch.full(
95 |             (labels.size(0), 1), start_token_id, dtype=torch.long, device=labels.device
96 |         )
97 | 
98 |         combined_decoder_input_ids = torch.cat(
99 |             [start_tokens, prompt_tokenized, labels], dim=-1
100 |         )
101 |         # Move all non-zero entries of each row to the front, keeping the first position at 0
102 |         temp_result = torch.zeros_like(combined_decoder_input_ids)
103 |         for i, row in enumerate(combined_decoder_input_ids):
104 |             non_zeros = row[row != 0]  # Extract non-zero elements
105 |             temp_result[i, :1] = 0  # Keep the first element as 0
106 |             temp_result[i, 1 : 1 + len(non_zeros)] = (
107 |                 non_zeros  # Place non-zero elements
108 |             )
109 |         combined_decoder_input_ids = temp_result
110 | 
111 |         attention_mask_decoder = (
112 |             combined_decoder_input_ids != self.tokenizer.pad_token_id
113 |         ).long()
114 |         attention_mask_decoder[:, 0] = 1
115 | 
116 |         # print("attention_mask_decoder:", attention_mask_decoder.sum(1))
117 |         # print("combined_decoder_input_ids:", combined_decoder_input_ids)
118 |         # 
print("attention_mask_decoder:", attention_mask_decoder) 119 | 120 | # if encoder_outputs is not None: 121 | logits = self.t5_model( 122 | # input_ids=input_ids, 123 | encoder_outputs=( 124 | encoder_outputs, 125 | ), # convert the encoder_outputs back to a tuple 126 | attention_mask=attention_mask, 127 | decoder_input_ids=combined_decoder_input_ids, 128 | decoder_attention_mask=attention_mask_decoder, 129 | ).logits[ 130 | :, :-1 131 | ] # (batch_size, prompt_len + tgt_seq_len, vocab) 132 | # print("logits encoder_outputs:", logits.sum()) 133 | 134 | # else: 135 | # logits = self.t5_model( 136 | # input_ids=input_ids, 137 | # # encoder_outputs=(encoder_outputs,), # convert the encoder_outputs back to a tuple 138 | # attention_mask=attention_mask, 139 | # decoder_input_ids=combined_decoder_input_ids, 140 | # decoder_attention_mask=attention_mask_decoder, 141 | # ).logits[:, :-1] # (batch_size, prompt_len + tgt_seq_len, vocab) 142 | # print("logits input_id:", logits.sum()) 143 | 144 | """If output is: 123000 || abc000 """ 145 | # logits = logits[:, -labels.size(1):] # (batch_size, tgt_seq_len, vocab) 146 | 147 | """If output is: 123abc || 000000 """ 148 | # print("length:", labels.size(1)) 149 | # Compute the start index for each row based on the number of non-zero values 150 | start_indices = (prompt_tokenized != 0).long().sum(dim=1) 151 | # print("start_indices:", start_indices) 152 | indices_range = torch.arange(labels.size(1)).unsqueeze(0).to( 153 | labels.device 154 | ) + start_indices.unsqueeze(1) 155 | # print("indices_range:", indices_range) 156 | batch_indices = torch.arange(logits.size(0)).unsqueeze(1).to(labels.device) 157 | # print("batch_indices:", batch_indices) 158 | 159 | logits = logits[ 160 | batch_indices, indices_range 161 | ] # (batch_size, tgt_seq_len, vocab) 162 | 163 | if added_ent_type_tokens is not None and added_pk_tokens is not None: 164 | # Apply constraints: candidates can only be {ent_type, pk, eos} 165 | large_negative = -1e9 166 | mask_token = torch.full_like(logits, large_negative).to(labels.device) 167 | # Apply constraints to all positions 168 | for token_set in [added_ent_type_tokens, added_pk_tokens, [end_token_id]]: 169 | mask_token[:, :, token_set] = 0 170 | 171 | logits += mask_token 172 | 173 | return logits 174 | 175 | def extract_entity_names(self, label_string, tokenizer): 176 | sep_token = "" 177 | sep_token_id = tokenizer.convert_tokens_to_ids(sep_token) 178 | token_ids = tokenizer.encode(label_string, add_special_tokens=False) 179 | 180 | entities = [] 181 | current_entity = [] 182 | for token_id in token_ids: 183 | if token_id == sep_token_id: 184 | if current_entity: 185 | entity_name = tokenizer.decode( 186 | current_entity, skip_special_tokens=True 187 | ) 188 | entities.append(entity_name) 189 | current_entity = [] 190 | else: 191 | current_entity.append(token_id) 192 | 193 | return entities 194 | 195 | def inference_generate_ids( 196 | self, 197 | prompt_tokenized, 198 | input_ids, 199 | attention_mask, 200 | max_length, 201 | added_ent_type_tokens=None, 202 | added_pk_tokens=None, 203 | ): 204 | batch_size = input_ids.size(0) 205 | # Initialize a tensor of ones 206 | all_predict_ids = torch.ones( 207 | batch_size, max_length, dtype=torch.long, device=input_ids.device 208 | ) 209 | 210 | # Determine the suppress_tokens based on added_ent_type_tokens and added_pk_tokens 211 | suppress_tokens = None 212 | if added_ent_type_tokens is not None and added_pk_tokens is not None: 213 | allowed_tokens = set( 214 | added_ent_type_tokens + 
added_pk_tokens + [1] 215 | ) # Including end token 216 | # print("allowed_tokens:", allowed_tokens) 217 | all_tokens = set(range(self.t5_model.config.vocab_size)) 218 | suppress_tokens = list(all_tokens - allowed_tokens) 219 | 220 | for idx in range(batch_size): 221 | single_prompt = prompt_tokenized[idx, :].unsqueeze(0) 222 | single_cleaned_prompt = single_prompt[single_prompt != 0].unsqueeze(0) 223 | 224 | single_input_ids = input_ids[idx, :].unsqueeze(0) 225 | single_attention_mask = attention_mask[idx, :].unsqueeze(0) 226 | 227 | forced_decoder_ids = [ 228 | [index + 1, element.item()] 229 | for index, element in enumerate(single_prompt[0]) 230 | if element.item() != 0 231 | ] 232 | 233 | # Generate predictions with suppress_tokens if applicable 234 | generate_args = { 235 | "input_ids": single_input_ids, 236 | "attention_mask": single_attention_mask, 237 | "forced_decoder_ids": forced_decoder_ids, 238 | "max_length": max_length, 239 | } 240 | if suppress_tokens is not None: 241 | generate_args["suppress_tokens"] = suppress_tokens 242 | 243 | predict_ids = self.t5_model.generate(**generate_args) 244 | # if added_ent_type_tokens is not None and added_pk_tokens is not None: 245 | # print("single_prompt:", single_prompt) 246 | # print("predict_ids:", predict_ids) 247 | 248 | prompt_size = len(forced_decoder_ids) 249 | trimmed_predict_ids = predict_ids[ 250 | :, prompt_size + 1 : 251 | ] # +1 due to the first generated token always being 0 252 | 253 | output_length = trimmed_predict_ids.size(1) 254 | all_predict_ids[idx, :output_length] = trimmed_predict_ids.squeeze(0) 255 | 256 | return all_predict_ids 257 | 258 | def forward( 259 | self, 260 | input_ids, # (b, seq_len) 261 | labels_ent_name, # (b, max_num_Entity * 6) 262 | real_labels_ent_name, # (b, max_num_Entity * 6) 263 | labels_pk, # (b, max_num_Entity, num_all_pks+2) 264 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len) 265 | attention_mask, 266 | attention_mask_ent_name, 267 | attention_mask_pk, 268 | attention_mask_pv, 269 | max_len_pv, 270 | device, 271 | added_ent_type_tokens, 272 | added_pk_tokens, 273 | mode="train", 274 | ): 275 | # Encode the input sequence once 276 | encoder_outputs = self.t5_model.encoder( 277 | input_ids=input_ids, 278 | attention_mask=attention_mask, 279 | output_hidden_states=False, 280 | return_dict=False, 281 | )[ 282 | 0 283 | ] # it is a tuple (xxx,), we use [0] to choose xxx 284 | 285 | """Step1: logits_ent""" 286 | prompt_step1 = f"predict entities" 287 | prompt_tokenized = ( 288 | self.tokenizer.encode( 289 | prompt_step1, add_special_tokens=False, return_tensors="pt" 290 | ) 291 | .repeat(input_ids.size(0), 1) 292 | .to(input_ids.device) 293 | ) # (b, x) 294 | 295 | if mode == "train": 296 | logits_ent = self._prompt_decoder_return_logits( 297 | prompt_tokenized, 298 | input_ids, 299 | encoder_outputs, 300 | attention_mask, 301 | labels_ent_name, 302 | ) # (b, max_num_Entity * 6, vocab) 303 | elif mode == "test": 304 | predict_ent_ids = self.inference_generate_ids( 305 | prompt_tokenized, input_ids, attention_mask, self.max_num_entity * 6 306 | ) 307 | 308 | """Step2: logits_pk""" 309 | max_prompt_length = 20 310 | batch_size, seq_len = input_ids.shape 311 | 312 | input_ids_ent_batch = [] 313 | enc_outputs_ent_batch = [] 314 | attention_mask_ent_batch = [] 315 | 316 | # Initialize a list for storing padded decoder inputs 317 | prompt_padded_batch_step2 = [] 318 | 319 | for sample_idx in range(batch_size): 320 | if mode == "train": 321 | # Extract the entity names from the ground 
truth labels 322 | entity_names = self.extract_entity_names( 323 | real_labels_ent_name[sample_idx], self.tokenizer 324 | ) 325 | elif mode == "test": 326 | # Decode the predicted entity names from Step 1 327 | predicted_ent = [ 328 | self.tokenizer.decode(ids, skip_special_tokens=False) 329 | for ids in predict_ent_ids 330 | ] 331 | entity_names = self.extract_entity_names( 332 | predicted_ent[sample_idx], self.tokenizer 333 | ) 334 | if entity_names == []: 335 | entity_names = ["Fail to predict"] 336 | 337 | for entity_name in entity_names: 338 | # Format the input for the T5 decoder 339 | prompt_sample = f"predict type and properties {entity_name}" 340 | # print("prompt_sample step2:", prompt_sample) 341 | prompt_sample_tokenized = self.tokenizer.encode( 342 | prompt_sample, add_special_tokens=False, return_tensors="pt" 343 | ).to(input_ids.device) 344 | 345 | # Pad the tokenized input to the max_decoder_length 346 | prompt_padded_sample = torch.nn.functional.pad( 347 | prompt_sample_tokenized, 348 | # (max_prompt_length - prompt_sample_tokenized.shape[1], 0), # Add padding at the beginning 349 | (0, max_prompt_length - prompt_sample_tokenized.shape[1]), 350 | value=self.tokenizer.pad_token_id, 351 | ) 352 | 353 | # Add to the list 354 | prompt_padded_batch_step2.append(prompt_padded_sample) 355 | 356 | # Repeat input_ids and attention_mask for each entity 357 | repeated_input_ids = input_ids[sample_idx].unsqueeze(0) 358 | repeated_enc_outputs = encoder_outputs[sample_idx].unsqueeze(0) 359 | repeated_attention_mask = attention_mask[sample_idx].unsqueeze(0) 360 | input_ids_ent_batch.append(repeated_input_ids) 361 | enc_outputs_ent_batch.append(repeated_enc_outputs) 362 | attention_mask_ent_batch.append(repeated_attention_mask) 363 | 364 | # Concatenate the repeated input_ids and attention masks to form a batch 365 | input_ids_ent_batch = torch.cat( 366 | input_ids_ent_batch, dim=0 367 | ) # (num_ent_in_batch, 512) 368 | enc_outputs_ent_batch = torch.cat( 369 | enc_outputs_ent_batch, dim=0 370 | ) # (num_ent_in_batch, 512, 512) 371 | attention_mask_ent_batch = torch.cat( 372 | attention_mask_ent_batch, dim=0 373 | ) # (num_ent_in_batch, 512) 374 | prompt_padded_batch_step2 = torch.cat( 375 | prompt_padded_batch_step2, dim=0 376 | ) # (num_ent_in_batch, max_prompt_length) 377 | 378 | if mode == "train": 379 | # remove all 0 tensors, since those are only for padding 380 | labels_pk_flatten = labels_pk[ 381 | (labels_pk != 0).any(dim=2) 382 | ] # (num_ent_in_batch, tgt_seq_len) 383 | logits_pk = self._prompt_decoder_return_logits( 384 | prompt_padded_batch_step2, 385 | input_ids_ent_batch, 386 | enc_outputs_ent_batch, 387 | attention_mask_ent_batch, 388 | labels_pk_flatten, 389 | added_ent_type_tokens=added_ent_type_tokens, 390 | added_pk_tokens=list(set(added_pk_tokens) - set(["pk_entity_name"])), 391 | ) # (num_ent_in_batch, tgt_seq_len, vocab) 392 | elif mode == "test": 393 | predict_pk_ids = self.inference_generate_ids( 394 | prompt_padded_batch_step2, 395 | input_ids_ent_batch, 396 | attention_mask_ent_batch, 397 | max_prompt_length + 10, 398 | added_ent_type_tokens=added_ent_type_tokens, 399 | added_pk_tokens=list(set(added_pk_tokens) - set(["pk_entity_name"])), 400 | ) 401 | # print("predict_pk_ids:", predict_pk_ids) 402 | 403 | """Step3: logits_pv""" 404 | max_prompt_length = 30 # Set a suitable max length for the decoder prompts 405 | 406 | input_ids_pv_batch = [] 407 | enc_outputs_pv_batch = [] 408 | attention_mask_pv_batch = [] 409 | prompt_padded_batch_step3 = [] 410 | 411 | for 
sample_idx in range(batch_size): 412 | if mode == "train": 413 | # Extract the entity names from the ground truth labels 414 | entity_names = self.extract_entity_names( 415 | real_labels_ent_name[sample_idx], self.tokenizer 416 | ) 417 | elif mode == "test": 418 | # Decode the predicted entity names from Step 1 419 | predicted_ent = [ 420 | self.tokenizer.decode(ids, skip_special_tokens=False) 421 | for ids in predict_ent_ids 422 | ] 423 | entity_names = self.extract_entity_names( 424 | predicted_ent[sample_idx], self.tokenizer 425 | ) 426 | labels_pk = predict_pk_ids.unsqueeze(0) 427 | 428 | for entity_idx, entity_name in enumerate(entity_names): 429 | # Get the entity type 430 | entity_type_id = labels_pk[sample_idx, entity_idx, 0] 431 | entity_type_token = self.tokenizer.decode([entity_type_id]) 432 | 433 | # Iterate through each property key for the entity, starting from column index 1 in labels_pk 434 | for pk_idx in range(1, labels_pk.size(2) - 1): 435 | pk_id = labels_pk[sample_idx, entity_idx, pk_idx] 436 | if pk_id in [ 437 | self.tokenizer.pad_token_id, 438 | self.tokenizer.eos_token_id, 439 | ]: 440 | continue 441 | 442 | # Retrieve property key token 443 | pk_token = self.tokenizer.decode([pk_id]) 444 | 445 | # Format the input for the T5 decoder 446 | prompt_sample = f"Predict property value for {entity_name} {entity_type_token} {pk_token}" 447 | # print("prompt_sample step3:", prompt_sample) 448 | prompt_sample_tokenized = self.tokenizer.encode( 449 | prompt_sample, add_special_tokens=False, return_tensors="pt" 450 | ).to(input_ids.device) 451 | 452 | # Pad the tokenized input to the max_prompt_length 453 | prompt_padded_sample = torch.nn.functional.pad( 454 | prompt_sample_tokenized, 455 | # (max_prompt_length - prompt_sample_tokenized.shape[1], 0), 456 | (0, max_prompt_length - prompt_sample_tokenized.shape[1]), 457 | value=self.tokenizer.pad_token_id, 458 | ) 459 | 460 | # Add to the list 461 | prompt_padded_batch_step3.append(prompt_padded_sample) 462 | 463 | # Repeat input_ids and attention_mask for each property key 464 | repeated_input_ids = input_ids[sample_idx].unsqueeze(0) 465 | repeated_enc_outputs = encoder_outputs[sample_idx].unsqueeze(0) 466 | repeated_attention_mask = attention_mask[sample_idx].unsqueeze(0) 467 | input_ids_pv_batch.append(repeated_input_ids) 468 | enc_outputs_pv_batch.append(repeated_enc_outputs) 469 | attention_mask_pv_batch.append(repeated_attention_mask) 470 | 471 | # Concatenate the repeated input_ids to form a batch 472 | if ( 473 | input_ids_pv_batch != [] 474 | ): # Edge case for inference, sometimes the model predicts nothing for step2 475 | input_ids_pv_batch = torch.cat( 476 | input_ids_pv_batch, dim=0 477 | ) # (num_pair_batch, seq_len) 478 | enc_outputs_pv_batch = torch.cat( 479 | enc_outputs_pv_batch, dim=0 480 | ) # (num_pair_batch, seq_len) 481 | attention_mask_pv_batch = torch.cat( 482 | attention_mask_pv_batch, dim=0 483 | ) # (num_pair_batch, seq_len) 484 | prompt_padded_batch_step3 = torch.cat( 485 | prompt_padded_batch_step3, dim=0 486 | ) # (num_pair_batch, max_prompt_length) 487 | 488 | # if enc_outputs_pv_batch != []: # Edge case for inference, sometimes the model predicts nothing for step2 489 | # 490 | # attention_mask_pv_batch = torch.cat(attention_mask_pv_batch, dim=0) # (num_pair_batch, seq_len) 491 | # prompt_padded_batch_step3 = torch.cat(prompt_padded_batch_step3, 492 | # dim=0) # (num_pair_batch, max_prompt_length) 493 | 494 | if mode == "train": 495 | if input_ids_pv_batch != []: 496 | # remove all 0 tensors, 
since those are only for padding 497 | labels_pv_flatten = labels_pv[ 498 | (labels_pv != self.tokenizer.pad_token_id).any(dim=3) 499 | ] # Flatten the labels_pv tensor 500 | # Predict the property values 501 | logits_pv = self._prompt_decoder_return_logits( 502 | prompt_padded_batch_step3, 503 | input_ids_pv_batch, 504 | enc_outputs_pv_batch, 505 | attention_mask_pv_batch, 506 | labels_pv_flatten, 507 | ) # (num_pair_batch, max_prop_len, vocab) 508 | else: # There is no other pk in addition to pk_ent_name. So we do not need to predict anything 509 | logits_pv = None 510 | elif mode == "test": 511 | if ( 512 | input_ids_pv_batch != [] 513 | ): # Edge case for inference, sometimes the model predicts nothing for step2 514 | # Generate the property value predictions 515 | predict_pv_ids = self.inference_generate_ids( 516 | prompt_padded_batch_step3, 517 | input_ids_pv_batch, 518 | attention_mask_pv_batch, 519 | max_length=max_prompt_length + 20, 520 | ) 521 | else: 522 | predict_pv_ids = torch.tensor([[0]]) 523 | 524 | if mode == "train": 525 | return logits_ent, logits_pk, logits_pv 526 | elif mode == "test": 527 | return predict_ent_ids, predict_pk_ids, predict_pv_ids 528 | 529 | 530 | class Trainer_E_Pk_Pv(Trainer): 531 | @staticmethod 532 | def calculate_loss( 533 | logits_ent, 534 | logits_pk, 535 | logits_pv, 536 | labels_ent_name, # (b, max_num_Entity * 6) 537 | labels_pk, # (b, max_num_Entity, num_all_pks+2) 538 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len) 539 | attention_mask_ent_name, 540 | attention_mask_pk, 541 | attention_mask_pv, 542 | added_ent_type_tokens, 543 | loss_mode, 544 | device, 545 | ): 546 | def calculate_loss_for_step( 547 | logits, 548 | labels, 549 | attention_mask, 550 | criterion, 551 | loss_mode, 552 | added_ent_type_tokens=None, 553 | end_token_id=1, 554 | sep_token_id=32098, 555 | ): 556 | # Flatten logits and labels 557 | logits_flatten = logits.reshape(-1, logits.size(-1)) 558 | labels_flatten = labels.view(-1) 559 | 560 | # Compute token-wise loss 561 | loss_tokenwise = criterion(logits_flatten, labels_flatten) 562 | 563 | # Apply attention mask 564 | loss_masked = loss_tokenwise.view_as(labels) * attention_mask 565 | 566 | # Initialize weights with ones 567 | weights = torch.ones_like(labels, dtype=torch.float, device=labels.device) 568 | 569 | # Increase weight for sep_token_id 570 | weights[labels == sep_token_id] = 2 571 | # weights[labels == end_token_id] = 2 572 | 573 | # # Increase weights for tokens in added_ent_type_tokens, if specified 574 | # if added_ent_type_tokens is not None: 575 | # for token_id in added_ent_type_tokens: 576 | # weights[labels == token_id] = 3 # Or another weight factor as desired 577 | 578 | # Apply weights to the loss 579 | loss_weighted = loss_masked * weights 580 | 581 | # Compute average loss 582 | valid_tokens_mask = attention_mask == 1 583 | if loss_mode == "mean": 584 | loss = ( 585 | loss_weighted.sum(dim=-1) / valid_tokens_mask.sum(dim=-1) 586 | ).mean() 587 | elif loss_mode == "sum": 588 | loss = (loss_weighted.sum(dim=-1) / valid_tokens_mask.sum(dim=-1)).sum() 589 | 590 | return loss 591 | 592 | criterion = torch.nn.CrossEntropyLoss( 593 | reduction="none" 594 | ) # Initialize without parameters 595 | 596 | """Step1: loss_ent""" 597 | # labels_ent_tokenized # (b, max_num_Entity * 3) 598 | loss_ent = calculate_loss_for_step( 599 | logits_ent, labels_ent_name, attention_mask_ent_name, criterion, loss_mode 600 | ) 601 | 602 | """Step2: loss_pk""" 603 | labels_pk_flatten = labels_pk[ 604 | 
(labels_pk != 0).any(dim=2) 605 | ] # (num_ent_batch, tgt_seq_len) 606 | attention_mask_pk_flatten = attention_mask_pk[ 607 | (attention_mask_pk != 0).any(dim=2) 608 | ] # (num_ent_batch, tgt_seq_len) 609 | loss_pk = calculate_loss_for_step( 610 | logits_pk, 611 | labels_pk_flatten, 612 | attention_mask_pk_flatten, 613 | criterion, 614 | loss_mode, 615 | added_ent_type_tokens, 616 | ) 617 | 618 | """Step3: loss_pv""" 619 | if logits_pv is not None: 620 | labels_pv_flatten = labels_pv[ 621 | (labels_pv != 0).any(dim=3) 622 | ] # (num_pair_batch, pv_seq_len) 623 | attention_mask_pv_flatten = attention_mask_pv[ 624 | (attention_mask_pv != 0).any(dim=3) 625 | ] # (num_pair_batch, pv_seq_len) 626 | loss_pv = calculate_loss_for_step( 627 | logits_pv, 628 | labels_pv_flatten, 629 | attention_mask_pv_flatten, 630 | criterion, 631 | loss_mode, 632 | ) 633 | else: 634 | loss_pv = torch.tensor(0).to(loss_pk.device) 635 | 636 | return loss_ent, loss_pk, loss_pv 637 | 638 | @staticmethod 639 | def compute_batch_loss( 640 | batch, 641 | model, 642 | added_ent_type_tokens, 643 | added_pk_tokens, 644 | loss_mode, 645 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), 646 | ): 647 | input_ids = batch["input_ids"].to(device) # (b, seq_len) 648 | labels_ent_name = batch["labels_ent_name"].to(device) # (b, max_num_Entity * 6) 649 | real_labels_ent_name = batch["real_labels_ent_name"] # (b, max_num_Entity * 6) 650 | labels_pk = batch["labels_pk"].to(device) # (b, max_num_Entity, num_all_pks+2) 651 | labels_pv = batch["labels_pv"].to( 652 | device 653 | ) # (b, max_num_Entity, num_all_pks, max_prop_len) 654 | attention_mask = batch["attention_mask"].to(device) 655 | attention_mask_ent_name = batch["attention_mask_ent_name"].to(device) 656 | attention_mask_pk = batch["attention_mask_pk"].to(device) 657 | attention_mask_pv = batch["attention_mask_pv"].to(device) 658 | max_len_pv = attention_mask_pv.shape[-1] 659 | 660 | # print("labels_ent_name:", labels_ent_name) 661 | # print("labels_pk:", labels_pk) 662 | # print("labels_pv:", labels_pv) 663 | 664 | logits_ent, logits_pk, logits_pv = model( 665 | input_ids, # (b, seq_len) 666 | labels_ent_name, # (b, max_num_Entity * 6) 667 | real_labels_ent_name, # (b, max_num_Entity * 6) 668 | labels_pk, # (b, max_num_Entity, num_all_pks+2) 669 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len) 670 | attention_mask, 671 | attention_mask_ent_name, 672 | attention_mask_pk, 673 | attention_mask_pv, 674 | max_len_pv, 675 | device, 676 | added_ent_type_tokens, 677 | added_pk_tokens, 678 | ) 679 | ( 680 | loss_ent, 681 | loss_pk, 682 | loss_pv, 683 | ) = Trainer_E_Pk_Pv.calculate_loss( 684 | logits_ent, 685 | logits_pk, 686 | logits_pv, 687 | labels_ent_name, # (b, max_num_Entity * 6) 688 | labels_pk, # (b, max_num_Entity, num_all_pks+2) 689 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len) 690 | attention_mask_ent_name, 691 | attention_mask_pk, 692 | attention_mask_pv, 693 | added_ent_type_tokens, 694 | loss_mode, 695 | device, 696 | ) 697 | return loss_ent, loss_pk, loss_pv 698 | 699 | @staticmethod 700 | def evaluate_full_dataloader( 701 | dataloader, 702 | model, 703 | added_ent_type_tokens, 704 | added_pk_tokens, 705 | loss_mode, 706 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), 707 | ): 708 | model.eval() 709 | 710 | total_loss_ent, total_loss_pk, total_loss_pv, total_loss = 0, 0, 0, 0 711 | with torch.no_grad(): 712 | for batch in dataloader: 713 | loss_ent, loss_pk, loss_pv = 
Trainer_E_Pk_Pv.compute_batch_loss( 714 | batch, 715 | model, 716 | added_ent_type_tokens, 717 | added_pk_tokens, 718 | loss_mode, 719 | device, 720 | ) 721 | total_loss_ent += loss_ent.item() 722 | total_loss_pk += loss_pk.item() 723 | total_loss_pv += loss_pv.item() 724 | total_loss += loss_ent.item() + loss_pk.item() + loss_pv.item() 725 | return ( 726 | total_loss_ent / len(dataloader), 727 | total_loss_pk / len(dataloader), 728 | total_loss_pv / len(dataloader), 729 | total_loss / len(dataloader), 730 | ) 731 | 732 | def train( 733 | self, 734 | save_path, 735 | model, 736 | train_dataloader, 737 | val_dataloader, 738 | optimizer, 739 | scheduler, 740 | epochs, 741 | **kwargs, 742 | ): 743 | device = kwargs.get( 744 | "device", torch.device("cuda" if torch.cuda.is_available() else "cpu") 745 | ) 746 | log_wandb = kwargs.get("log_wandb", False) 747 | use_lora = kwargs.get("use_lora", "True") 748 | alpha = kwargs.get("alpha", 0.5) 749 | added_ent_type_tokens = kwargs.get("added_ent_type_tokens", None) 750 | added_pk_tokens = kwargs.get("added_pk_tokens", None) 751 | loss_mode = kwargs.get("loss_mode", None) 752 | report = kwargs.get("reporter", None) 753 | # # Start wandb 754 | # if log_wandb: 755 | # run = wandb.init(project="your-project-name", entity="your-entity-name") 756 | 757 | # Initialize variables for early stopping 758 | no_improve_step, no_improve_epochs = 0, 0 759 | min_train_loss, min_val_loss = float("inf"), float("inf") 760 | 761 | # Compute and log the initial loss before training 762 | print("Monitor Epoch loss...") 763 | ( 764 | avg_train_loss_ent, 765 | avg_train_loss_pk, 766 | avg_train_loss_pv, 767 | avg_train_loss, 768 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader( 769 | train_dataloader, 770 | model, 771 | added_ent_type_tokens, 772 | added_pk_tokens, 773 | loss_mode, 774 | device=device, 775 | ) 776 | ( 777 | avg_val_loss_ent, 778 | avg_val_loss_pk, 779 | avg_val_loss_pv, 780 | avg_val_loss, 781 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader( 782 | val_dataloader, 783 | model, 784 | added_ent_type_tokens, 785 | added_pk_tokens, 786 | loss_mode, 787 | device=device, 788 | ) 789 | 790 | # report(epoch=0, validation_loss=avg_val_loss) 791 | print(f"Epoch: {0}") 792 | print( 793 | f"Train: loss_ent: {avg_train_loss_ent}, loss_pk: {avg_train_loss_pk}, loss_pv: {avg_train_loss_pv}. loss: {avg_train_loss}" 794 | ) 795 | print( 796 | f"Val: loss_ent: {avg_val_loss_ent}, loss_pk: {avg_val_loss_pk}, loss_pv: {avg_val_loss_pv}. 
loss: {avg_val_loss}" 797 | ) 798 | 799 | if log_wandb: 800 | wandb.log({"Epoch": 0}) 801 | wandb.log( 802 | { 803 | "train_loss_ent": avg_train_loss_ent, 804 | "train_loss_pk": avg_train_loss_pk, 805 | "train_loss_pv": avg_train_loss_pv, 806 | "train_loss": avg_train_loss, 807 | } 808 | ) 809 | wandb.log( 810 | { 811 | "val_loss_ent": avg_val_loss_ent, 812 | "val_loss_pk": avg_val_loss_pk, 813 | "val_loss_pv": avg_val_loss_pv, 814 | "val_loss": avg_val_loss, 815 | } 816 | ) 817 | 818 | # num_step = 0 819 | for epoch in range(epochs): 820 | # if use_lora: 821 | # print(f"t5_model.shared.original_module", model.t5_model.shared.original_module.weight.sum()) 822 | # print(f"t5_model.shared.modules_to_save", model.t5_model.shared.modules_to_save['default'].weight.sum()) 823 | start_time = time.time() 824 | model.train() 825 | for batch in train_dataloader: 826 | optimizer.zero_grad() 827 | ( 828 | loss_ent, 829 | loss_pk, 830 | loss_pv, 831 | ) = Trainer_E_Pk_Pv.compute_batch_loss( 832 | batch, model, added_ent_type_tokens, added_pk_tokens, loss_mode 833 | ) 834 | 835 | # batch_loss = loss_ent * alpha + loss_pk * (1 - alpha) / 2 + loss_pv * (1 - alpha) / 2 836 | batch_loss = loss_ent + loss_pk + loss_pv 837 | loss = batch_loss 838 | loss.backward() 839 | optimizer.step() 840 | scheduler.step() 841 | 842 | # Compute loss 843 | print("Monitor loss at epoch...") 844 | ( 845 | avg_train_loss_ent, 846 | avg_train_loss_pk, 847 | avg_train_loss_pv, 848 | avg_train_loss, 849 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader( 850 | train_dataloader, 851 | model, 852 | added_ent_type_tokens, 853 | added_pk_tokens, 854 | loss_mode, 855 | device=device, 856 | ) 857 | ( 858 | avg_val_loss_ent, 859 | avg_val_loss_pk, 860 | avg_val_loss_pv, 861 | avg_val_loss, 862 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader( 863 | val_dataloader, 864 | model, 865 | added_ent_type_tokens, 866 | added_pk_tokens, 867 | loss_mode, 868 | device=device, 869 | ) 870 | 871 | print(f"Epoch: {epoch + 1}") 872 | print( 873 | f"Train: loss_ent: {avg_train_loss_ent}, loss_pk: {avg_train_loss_pk}, loss_pv: {avg_train_loss_pv}. loss: {avg_train_loss}" 874 | ) 875 | print( 876 | f"Val: loss_ent: {avg_val_loss_ent}, loss_pk: {avg_val_loss_pk}, loss_pv: {avg_val_loss_pv}. loss: {avg_val_loss}" 877 | ) 878 | report(epoch=epoch + 1, validation_loss=avg_val_loss) 879 | 880 | if log_wandb: 881 | wandb.log({"Epoch": epoch + 1}) 882 | wandb.log( 883 | { 884 | "train_loss_ent": avg_train_loss_ent, 885 | "train_loss_pk": avg_train_loss_pk, 886 | "train_loss_pv": avg_train_loss_pv, 887 | "train_loss": avg_train_loss, 888 | } 889 | ) 890 | wandb.log( 891 | { 892 | "val_loss_ent": avg_val_loss_ent, 893 | "val_loss_pk": avg_val_loss_pk, 894 | "val_loss_pv": avg_val_loss_pv, 895 | "val_loss": avg_val_loss, 896 | } 897 | ) 898 | 899 | # Check for early stopping 900 | if avg_val_loss < min_val_loss: 901 | print(f"Save model... 
(epoch)") 902 | no_improve_epochs = 0 903 | min_val_loss = avg_val_loss 904 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 905 | 906 | if use_lora: 907 | model_path = f"{save_path}" 908 | model.t5_model.shared.modules_to_save["default"] = ( 909 | model.t5_model.shared.original_module 910 | ) 911 | model.save_pretrained(model_path) 912 | # model = model.merge_and_unload() 913 | else: 914 | model_path = f"{save_path}.pt" 915 | torch.save(model.state_dict(), model_path) 916 | else: 917 | no_improve_epochs += 1 918 | 919 | if no_improve_epochs == 10: 920 | print(f"Early stopping at epoch {epoch + 1}") 921 | break 922 | print("Time for epoch:", time.time() - start_time) 923 | # if log_wandb: 924 | # run.finish() 925 | 926 | def evaluate(self, model, test_dataloader, tokenizer, **kwargs): 927 | pass 928 | 929 | @staticmethod 930 | def generate_full_json_output( 931 | model, 932 | dataloader, 933 | added_ent_type_tokens, 934 | added_pk_tokens, 935 | tokenizer, 936 | device, 937 | mode, 938 | ): 939 | 940 | def extract_elements(given_list): 941 | extracted = [] 942 | for item in given_list: 943 | elements = item.split("") 944 | for element in elements: 945 | cleaned_element = element.strip() 946 | if cleaned_element and not cleaned_element.startswith("<"): 947 | extracted.append(cleaned_element) 948 | return extracted 949 | 950 | model.eval() 951 | results = [] 952 | 953 | text_index = 0 954 | 955 | with torch.no_grad(): 956 | for batch in dataloader: 957 | start_time = time.time() 958 | input_ids = batch["input_ids"].to(device) # (b, seq_len) 959 | # labels_ent = batch["labels_ent"].to(device) # (b, max_num_Entity * 2) 960 | # labels_ent_tokenized = batch["labels_ent_tokenized"].to(device) # (b, max_num_Entity * 3) 961 | real_labels_ent_name = batch[ 962 | "real_labels_ent_name" 963 | ] # (b, max_num_Entity * 6) 964 | labels_ent_name = batch["labels_ent_name"].to( 965 | device 966 | ) # (b, max_num_Entity * 6) 967 | labels_pk = batch["labels_pk"].to( 968 | device 969 | ) # (b, max_num_Entity, num_all_pks+2) 970 | labels_pv = batch["labels_pv"].to( 971 | device 972 | ) # (b, max_num_Entity, num_all_pks, max_prop_len) 973 | attention_mask = batch["attention_mask"].to(device) 974 | # attention_mask_ent = batch["attention_mask_ent"].to(device) 975 | # attention_mask_ent_tokenized = batch["attention_mask_ent_tokenized"].to(device) 976 | attention_mask_ent_name = batch["attention_mask_ent_name"].to(device) 977 | attention_mask_pk = batch["attention_mask_pk"].to(device) 978 | attention_mask_pv = batch["attention_mask_pv"].to(device) 979 | max_len_pv = attention_mask_pv.shape[-1] 980 | 981 | predict_ent_ids, predict_pk_ids, predict_pv_ids = model( 982 | input_ids, # (b, seq_len) 983 | labels_ent_name, # (b, max_num_Entity * 6) 984 | real_labels_ent_name, # (b, max_num_Entity * 6) 985 | labels_pk, # (b, max_num_Entity, num_all_pks+2) 986 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len) 987 | attention_mask, 988 | attention_mask_ent_name, 989 | attention_mask_pk, 990 | attention_mask_pv, 991 | max_len_pv, 992 | device, 993 | added_ent_type_tokens, 994 | added_pk_tokens, 995 | mode="test", 996 | ) 997 | 998 | # print("labels_pv:", labels_pv.shape) 999 | # print("predict_ent_ids:", predict_ent_ids.shape, predict_ent_ids) 1000 | # print("predict_pk_ids:", predict_pk_ids.shape, predict_pk_ids) 1001 | # print("predict_pv_ids:", predict_pv_ids.shape, predict_pv_ids) 1002 | 1003 | # predict_pk_ids = replace_with_closest_embedding(predict_pk_ids, added_ent_type_tokens, added_pk_tokens, 1004 
| # model, device) 1005 | 1006 | # print("predict_pk_ids:", predict_pk_ids) 1007 | 1008 | """Format prediction""" 1009 | predict_ent_tokens = [ 1010 | tokenizer.decode(ids, skip_special_tokens=False) 1011 | for ids in predict_ent_ids 1012 | ] 1013 | res_predict_ent_name = extract_elements(predict_ent_tokens) 1014 | # predict_pk_tokens = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predict_pk_ids] 1015 | # predict_pv_tokens = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predict_pv_ids] 1016 | 1017 | # Calculate the number of properties for each entity, subtracting 1 for the type 1018 | num_properties_per_entity = [ 1019 | (pk_ids != 1).sum().item() - 1 for pk_ids in predict_pk_ids 1020 | ] 1021 | 1022 | # Decode predict_pk_ids 1023 | predict_pk_tokens = [ 1024 | tokenizer.decode(pk_ids, skip_special_tokens=True) 1025 | for pk_ids in predict_pk_ids 1026 | ] 1027 | 1028 | # Initialize variables 1029 | predict_pv_tokens = [] 1030 | start_index = 0 1031 | 1032 | # Process each entity's property values 1033 | for num_props in num_properties_per_entity: 1034 | # Extract the property values for this entity 1035 | entity_pv_ids = predict_pv_ids[ 1036 | start_index : start_index + num_props 1037 | ] 1038 | 1039 | # Decode and handle value 1 as empty string 1040 | entity_pv_tokens = [ 1041 | ( 1042 | tokenizer.decode(pv_ids, skip_special_tokens=True) 1043 | if pv_ids[0] != 1 1044 | else "" 1045 | ) 1046 | for pv_ids in entity_pv_ids 1047 | ] 1048 | 1049 | # Append to the main list 1050 | predict_pv_tokens.append(entity_pv_tokens) 1051 | 1052 | # Update the start index for the next entity 1053 | start_index += num_props 1054 | 1055 | """Format ground truth""" 1056 | res_real_ent_name = extract_elements(real_labels_ent_name) 1057 | 1058 | labels_pk = labels_pk[(labels_pk != 0).any(dim=-1)] 1059 | labels_pk_flat = labels_pk.view( 1060 | -1, labels_pk.size(-1) 1061 | ) # Flattening to 2D 1062 | res_real_pk = [ 1063 | tokenizer.decode(ids, skip_special_tokens=True) 1064 | for ids in labels_pk_flat 1065 | ] 1066 | 1067 | res_real_pv = [] 1068 | for block in labels_pv[0]: # Assuming the first dimension is always 1 1069 | # Filter out zero rows 1070 | filtered_block = block[(block != 0).any(dim=-1)] 1071 | 1072 | # Check if the filtered block is not empty 1073 | if filtered_block.size(0) != 0: 1074 | # Decode each row in the filtered block 1075 | decoded_rows = [ 1076 | tokenizer.decode(row, skip_special_tokens=True) 1077 | for row in filtered_block 1078 | ] 1079 | res_real_pv.append(decoded_rows) 1080 | 1081 | # Append the results of this batch to the 'batches' key in the results dictionary 1082 | results.append( 1083 | { 1084 | "predict_ent": res_predict_ent_name, 1085 | "predict_pk": predict_pk_tokens, 1086 | "predict_pv": predict_pv_tokens, 1087 | } 1088 | ) 1089 | 1090 | # Optional: Print the current batch results 1091 | print("Batch", text_index) 1092 | print("truth_ent:", res_real_ent_name) 1093 | print("truth_pk:", res_real_pk) 1094 | print("truth_pv:", res_real_pv) 1095 | print() 1096 | 1097 | print("predict_ent:", res_predict_ent_name) 1098 | print("predict_pk:", predict_pk_tokens) 1099 | print("predict_pv:", predict_pv_tokens) 1100 | print("---------------------") 1101 | text_index += 1 1102 | 1103 | return results 1104 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as 
np 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.nn as nn 8 | from peft import LoraConfig, PrefixTuningConfig, TaskType, get_peft_model 9 | from scipy.sparse import csr_matrix 10 | from scipy.sparse.csgraph import min_weight_full_bipartite_matching 11 | from torch.nn import MultiheadAttention 12 | from transformers import (AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, 13 | DataCollatorWithPadding, GPT2LMHeadModel, 14 | GPT2Tokenizer, LlamaForCausalLM, LlamaTokenizer, 15 | OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, 16 | OpenLlamaConfig, OpenLlamaForCausalLM, 17 | T5ForConditionalGeneration, T5Tokenizer, Trainer, 18 | TrainingArguments, TransfoXLLMHeadModel, 19 | TransfoXLTokenizer) 20 | from transformers.models.t5.modeling_t5 import T5Attention 21 | 22 | 23 | def set_seed(seed): 24 | torch.cuda.manual_seed(seed) 25 | torch.manual_seed(seed) 26 | np.random.seed(seed) 27 | random.seed(seed) 28 | cudnn.deterministic = True 29 | cudnn.benchmark = False 30 | 31 | 32 | def plot_property_stats( 33 | property_accuracies, property_counts, fig_size=(10, 5), path_name=None 34 | ): 35 | # Create figure and axis 36 | fig, ax1 = plt.subplots(figsize=fig_size) 37 | 38 | # Sort property_counts by value, high to low 39 | sorted_property_counts = { 40 | k: v 41 | for k, v in sorted( 42 | property_counts.items(), key=lambda item: item[1], reverse=True 43 | ) 44 | } 45 | 46 | # Bar plot with property counts 47 | ax1.bar( 48 | sorted_property_counts.keys(), 49 | sorted_property_counts.values(), 50 | color="b", 51 | alpha=0.5, 52 | ) 53 | ax1.set_xlabel("Property Name") 54 | ax1.set_ylabel("Counts", color="b") 55 | ax1.tick_params(axis="y", labelcolor="b") 56 | 57 | # Rotate x-labels 90 degrees 58 | plt.xticks(rotation=90) 59 | 60 | # Create a second y-axis that shares the same x-axis, we already handled the x-label with ax1 61 | ax2 = ax1.twinx() 62 | 63 | # Line plot with property accuracies 64 | # Ensure the properties in sorted_property_accuracies follow the same order in sorted_property_counts 65 | sorted_property_accuracies = { 66 | k: property_accuracies[k] for k in sorted_property_counts.keys() 67 | } 68 | ax2.plot( 69 | sorted_property_accuracies.keys(), 70 | sorted_property_accuracies.values(), 71 | color="r", 72 | ) 73 | ax2.set_ylabel("Accuracy", color="r") # we already handled the x-label with ax1 74 | ax2.tick_params(axis="y", labelcolor="r") 75 | 76 | # Layout 77 | fig.tight_layout() 78 | 79 | # Save the figure if a path is provided 80 | if path_name is not None: 81 | plt.savefig(path_name, dpi=300, bbox_inches="tight") 82 | 83 | plt.show() 84 | 85 | 86 | def get_generative_model_and_tokenizer(config): 87 | if config.saved_model_path: 88 | print("Loading pretrained model at", config.saved_model_path) 89 | 90 | if config.generative_model in ("gpt2", "gpt2-large"): 91 | model_path = config.saved_model_path or config.generative_model 92 | kwargs = { 93 | "pretrained_model_name_or_path": model_path, 94 | # "device_map": 'auto', 95 | } 96 | if hasattr(config, "torch_dtype"): 97 | if config.torch_dtype == "float16": 98 | kwargs["torch_dtype"] = torch.float16 99 | elif config.torch_dtype != "float32": 100 | raise ValueError( 101 | f"torch_dtype: {config.torch_dtype} not recognized in config file." 
102 |                 )
103 |         tokenizer = GPT2Tokenizer.from_pretrained(config.generative_model)
104 |         model = GPT2LMHeadModel.from_pretrained(**kwargs)
105 |         tokenizer.pad_token = tokenizer.eos_token
106 |     elif config.generative_model == "custom":
107 |         model_config = OpenLlamaConfig(  # separate name; do not shadow the args namespace
108 |             vocab_size=32000,
109 |             hidden_size=config.hidden_size,
110 |             intermediate_size=config.intermediate_size,
111 |             num_hidden_layers=config.num_hidden_layers,
112 |             max_position_embeddings=config.max_position_embeddings,
113 |         )
114 |         model = OpenLlamaForCausalLM(config=model_config)
115 |         tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
116 |         tokenizer.pad_token_id = 0
117 |     elif config.generative_model == "llama_3B":
118 |         model_path = config.saved_model_path or "openlm-research/open_llama_3b_v2"
119 |         tokenizer = AutoTokenizer.from_pretrained(model_path)
120 |         kwargs = {
121 |             "pretrained_model_name_or_path": model_path,
122 |             # "device_map": 'auto',
123 |         }
124 |         if hasattr(config, "torch_dtype"):
125 |             if config.torch_dtype == "float16":
126 |                 kwargs["torch_dtype"] = torch.float16
127 |             elif config.torch_dtype != "float32":
128 |                 raise ValueError(
129 |                     f"torch_dtype: {config.torch_dtype} not recognized in config file."
130 |                 )
131 |         model = LlamaForCausalLM.from_pretrained(**kwargs)
132 |         tokenizer.pad_token = tokenizer.eos_token
133 |         # model.tie_weights()
134 |     elif "flan-t5" in config.generative_model:
135 |         model_path = "google/" + config.generative_model
136 |         kwargs = {
137 |             "pretrained_model_name_or_path": model_path,
138 |             # "device_map": 'auto',
139 |         }
140 |         if hasattr(config, "torch_dtype"):
141 |             if config.torch_dtype == "float16":
142 |                 kwargs["torch_dtype"] = torch.float16
143 |             elif config.torch_dtype != "float32":
144 |                 raise ValueError(
145 |                     f"torch_dtype: {config.torch_dtype} not recognized in config file."
146 |                 )
147 |         model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
148 |         tokenizer = AutoTokenizer.from_pretrained(model_path)
149 |     elif config.generative_model in ["t5-small", "t5-base", "t5-large"]:
150 |         model_path = config.saved_model_path or config.generative_model
151 |         kwargs = {
152 |             "pretrained_model_name_or_path": model_path,
153 |         }
154 |         if hasattr(config, "torch_dtype"):
155 |             if config.torch_dtype == "float16":
156 |                 kwargs["torch_dtype"] = torch.float16
157 |             elif config.torch_dtype != "float32":
158 |                 raise ValueError(
159 |                     f"torch_dtype: {config.torch_dtype} not recognized in config file."
160 |                 )
161 |         # model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
162 |         # tokenizer = AutoTokenizer.from_pretrained(model_path)
163 |         model = T5ForConditionalGeneration.from_pretrained(model_path)
164 |         # print("T5 model", model)
165 |         tokenizer = T5Tokenizer.from_pretrained(model_path)
166 |         # Add the new tokens to the tokenizer
167 |         new_tokens = ["", "", "", "{", "}"]
168 |         tokenizer.add_tokens(new_tokens)
169 |     else:
170 |         raise ValueError(f'model name "{config.generative_model}" not recognized')
171 | 
172 |     # Apply LoRA training if config.use_lora is True
173 |     if config.use_lora:
174 |         lora_config = LoraConfig(
175 |             r=config.lora_r,
176 |             lora_alpha=config.lora_alpha,
177 |             lora_dropout=config.lora_dropout,
178 |             task_type=(
179 |                 TaskType.SEQ_2_SEQ_LM
180 |                 if "t5" in config.generative_model  # all T5 variants are seq2seq
181 |                 else TaskType.CAUSAL_LM
182 |             ),
183 |             target_modules=config.lora_target_modules,
184 |             modules_to_save=config.lora_modules_to_save,
185 |         )
186 |         model = get_peft_model(model, lora_config)
187 |         model.print_trainable_parameters()
188 | 
189 |     return model, tokenizer
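A minimal sketch of how get_generative_model_and_tokenizer is meant to be called. The SimpleNamespace stands in for the project's parsed arguments (see args.py), and every field value here is an illustrative assumption:

from types import SimpleNamespace

from utils import get_generative_model_and_tokenizer

config = SimpleNamespace(
    generative_model="t5-small",  # selects the plain-T5 branch above
    saved_model_path=None,        # fall back to the hub checkpoint
    use_lora=False,               # skip the LoRA wrapping at the end
)
model, tokenizer = get_generative_model_and_tokenizer(config)

With use_lora=True, the config would additionally need lora_r, lora_alpha, lora_dropout, lora_target_modules, and lora_modules_to_save.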
190 | 
191 | 
192 | def compute_inverse_frequency_weights(entity_type_counts, num_entity_types):
193 |     # Extract counts
194 |     counts = list(entity_type_counts.values())
195 |     # Compute inverse frequency
196 |     inverse_freq = [1.0 / count for count in counts]
197 |     # Normalize (optional, but it helps in cases where you'd want the weights to be relative to the highest class weight)
198 |     total = sum(inverse_freq)
199 |     normalized_weights = [freq / total for freq in inverse_freq]
200 | 
201 |     return torch.tensor(normalized_weights, dtype=torch.float32)
202 | 
203 | 
204 | def print_trainable_parameters(model):
205 |     trainable_params = 0
206 |     all_param = 0
207 |     for _, param in model.named_parameters():
208 |         all_param += param.numel()
209 |         if param.requires_grad:
210 |             trainable_params += param.numel()
211 |     print(
212 |         f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
213 |     )
214 | 
215 | 
216 | def get_attention_paths(model, path=""):
217 |     paths = []
218 |     for name, module in model.named_children():
219 |         new_path = f"{path}.{name}" if path else name
220 | 
221 |         if isinstance(module, T5Attention):
222 |             paths.append(f"{new_path}.q")
223 |             # paths.append(f"{new_path}.k")
224 |             paths.append(f"{new_path}.v")
225 |             # paths.append(f"{new_path}.o")
226 |         else:
227 |             paths.extend(get_attention_paths(module, new_path))
228 | 
229 |     return paths
230 | 
231 | 
232 | def get_transformerlayer_paths(model, path=""):
233 |     paths = []
234 |     for name, module in model.named_children():
235 |         new_path = f"{path}.{name}" if path else name
236 | 
237 |         if isinstance(module, nn.TransformerEncoderLayer):
238 |             paths.append(new_path)
239 |         else:
240 |             paths.extend(get_transformerlayer_paths(module, new_path))
241 | 
242 |     return paths
243 | 
244 | 
245 | def remove_duplicates_and_postprocess(entity_lst):
246 |     def postprocess(entity):
247 |         for key, value in entity.items():
248 |             if not isinstance(value, str):
249 |                 if isinstance(value, list):
250 |                     try:
251 |                         value = " ".join(value)
252 |                     except Exception:  # non-string list items fall back to str()
253 |                         value = str(value)
254 |                 else:
255 |                     value = str(value)
256 |             entity[key] = value
257 |         return entity
258 | 
259 |     new_lst = []
260 |     for e in entity_lst:
261 |         e = postprocess(e)
262 |         if e not in new_lst:
263 |             new_lst.append(e)
264 |     return new_lst
265 | 
--------------------------------------------------------------------------------
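To close, a short sketch (illustrative, not part of the repository) of how get_attention_paths pairs with the LoRA setup in get_generative_model_and_tokenizer: it walks the module tree and returns dotted paths to the q and v projections of every T5Attention block, which is the form LoraConfig's target_modules expects:

from transformers import T5ForConditionalGeneration

from utils import get_attention_paths

model = T5ForConditionalGeneration.from_pretrained("t5-small")
lora_targets = get_attention_paths(model)
# e.g. "encoder.block.0.layer.0.SelfAttention.q",
#      "encoder.block.0.layer.0.SelfAttention.v", ...
print(len(lora_targets), lora_targets[:2])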