├── .github
│   └── dependabot.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── args.py
├── compute_metrics.py
├── data
│   ├── __init__.py
│   └── dataloader_musee.py
├── experiment_musee.py
├── img
│   ├── metric.png
│   └── model.png
├── metric.py
├── model
│   ├── __init__.py
│   ├── t5_with_t5decoder.py
│   ├── t5_with_t5decoder_emb.py
│   └── utils
│       ├── __init__.py
│       ├── dist_util.py
│       ├── fp16_util.py
│       ├── logger.py
│       ├── losses.py
│       └── nn.py
├── requirements.txt
├── trainer
│   ├── __init__.py
│   └── trainer_musee.py
└── utils.py
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "pip"
9 | directory: "/"
10 | schedule:
11 | interval: "weekly"
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and
2 | ## files generated by popular Visual Studio add-ons.
3 | ##
4 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore
5 |
6 | # User-specific files
7 | *.rsuser
8 | *.suo
9 | *.user
10 | *.userosscache
11 | *.sln.docstates
12 |
13 | # User-specific files (MonoDevelop/Xamarin Studio)
14 | *.userprefs
15 |
16 | # Mono auto generated files
17 | mono_crash.*
18 |
19 | # Build results
20 | [Dd]ebug/
21 | [Dd]ebugPublic/
22 | [Rr]elease/
23 | [Rr]eleases/
24 | x64/
25 | x86/
26 | [Ww][Ii][Nn]32/
27 | [Aa][Rr][Mm]/
28 | [Aa][Rr][Mm]64/
29 | bld/
30 | [Bb]in/
31 | [Oo]bj/
32 | [Ll]og/
33 | [Ll]ogs/
34 |
35 | # Visual Studio 2015/2017 cache/options directory
36 | .vs/
37 | # Uncomment if you have tasks that create the project's static files in wwwroot
38 | #wwwroot/
39 |
40 | # Visual Studio 2017 auto generated files
41 | Generated\ Files/
42 |
43 | # MSTest test Results
44 | [Tt]est[Rr]esult*/
45 | [Bb]uild[Ll]og.*
46 |
47 | # NUnit
48 | *.VisualState.xml
49 | TestResult.xml
50 | nunit-*.xml
51 |
52 | # Build Results of an ATL Project
53 | [Dd]ebugPS/
54 | [Rr]eleasePS/
55 | dlldata.c
56 |
57 | # Benchmark Results
58 | BenchmarkDotNet.Artifacts/
59 |
60 | # .NET Core
61 | project.lock.json
62 | project.fragment.lock.json
63 | artifacts/
64 |
65 | # ASP.NET Scaffolding
66 | ScaffoldingReadMe.txt
67 |
68 | # StyleCop
69 | StyleCopReport.xml
70 |
71 | # Files built by Visual Studio
72 | *_i.c
73 | *_p.c
74 | *_h.h
75 | *.ilk
76 | *.meta
77 | *.obj
78 | *.iobj
79 | *.pch
80 | *.pdb
81 | *.ipdb
82 | *.pgc
83 | *.pgd
84 | *.rsp
85 | *.sbr
86 | *.tlb
87 | *.tli
88 | *.tlh
89 | *.tmp
90 | *.tmp_proj
91 | *_wpftmp.csproj
92 | *.log
93 | *.tlog
94 | *.vspscc
95 | *.vssscc
96 | .builds
97 | *.pidb
98 | *.svclog
99 | *.scc
100 |
101 | # Chutzpah Test files
102 | _Chutzpah*
103 |
104 | # Visual C++ cache files
105 | ipch/
106 | *.aps
107 | *.ncb
108 | *.opendb
109 | *.opensdf
110 | *.sdf
111 | *.cachefile
112 | *.VC.db
113 | *.VC.VC.opendb
114 |
115 | # Visual Studio profiler
116 | *.psess
117 | *.vsp
118 | *.vspx
119 | *.sap
120 |
121 | # Visual Studio Trace Files
122 | *.e2e
123 |
124 | # TFS 2012 Local Workspace
125 | $tf/
126 |
127 | # Guidance Automation Toolkit
128 | *.gpState
129 |
130 | # ReSharper is a .NET coding add-in
131 | _ReSharper*/
132 | *.[Rr]e[Ss]harper
133 | *.DotSettings.user
134 |
135 | # TeamCity is a build add-in
136 | _TeamCity*
137 |
138 | # DotCover is a Code Coverage Tool
139 | *.dotCover
140 |
141 | # AxoCover is a Code Coverage Tool
142 | .axoCover/*
143 | !.axoCover/settings.json
144 |
145 | # Coverlet is a free, cross platform Code Coverage Tool
146 | coverage*.json
147 | coverage*.xml
148 | coverage*.info
149 |
150 | # Visual Studio code coverage results
151 | *.coverage
152 | *.coveragexml
153 |
154 | # NCrunch
155 | _NCrunch_*
156 | .*crunch*.local.xml
157 | nCrunchTemp_*
158 |
159 | # MightyMoose
160 | *.mm.*
161 | AutoTest.Net/
162 |
163 | # Web workbench (sass)
164 | .sass-cache/
165 |
166 | # Installshield output folder
167 | [Ee]xpress/
168 |
169 | # DocProject is a documentation generator add-in
170 | DocProject/buildhelp/
171 | DocProject/Help/*.HxT
172 | DocProject/Help/*.HxC
173 | DocProject/Help/*.hhc
174 | DocProject/Help/*.hhk
175 | DocProject/Help/*.hhp
176 | DocProject/Help/Html2
177 | DocProject/Help/html
178 |
179 | # Click-Once directory
180 | publish/
181 |
182 | # Publish Web Output
183 | *.[Pp]ublish.xml
184 | *.azurePubxml
185 | # Note: Comment the next line if you want to checkin your web deploy settings,
186 | # but database connection strings (with potential passwords) will be unencrypted
187 | *.pubxml
188 | *.publishproj
189 |
190 | # Microsoft Azure Web App publish settings. Comment the next line if you want to
191 | # checkin your Azure Web App publish settings, but sensitive information contained
192 | # in these scripts will be unencrypted
193 | PublishScripts/
194 |
195 | # NuGet Packages
196 | *.nupkg
197 | # NuGet Symbol Packages
198 | *.snupkg
199 | # The packages folder can be ignored because of Package Restore
200 | **/[Pp]ackages/*
201 | # except build/, which is used as an MSBuild target.
202 | !**/[Pp]ackages/build/
203 | # Uncomment if necessary however generally it will be regenerated when needed
204 | #!**/[Pp]ackages/repositories.config
205 | # NuGet v3's project.json files produces more ignorable files
206 | *.nuget.props
207 | *.nuget.targets
208 |
209 | # Microsoft Azure Build Output
210 | csx/
211 | *.build.csdef
212 |
213 | # Microsoft Azure Emulator
214 | ecf/
215 | rcf/
216 |
217 | # Windows Store app package directories and files
218 | AppPackages/
219 | BundleArtifacts/
220 | Package.StoreAssociation.xml
221 | _pkginfo.txt
222 | *.appx
223 | *.appxbundle
224 | *.appxupload
225 |
226 | # Visual Studio cache files
227 | # files ending in .cache can be ignored
228 | *.[Cc]ache
229 | # but keep track of directories ending in .cache
230 | !?*.[Cc]ache/
231 |
232 | # Others
233 | ClientBin/
234 | ~$*
235 | *~
236 | *.dbmdl
237 | *.dbproj.schemaview
238 | *.jfm
239 | *.pfx
240 | *.publishsettings
241 | orleans.codegen.cs
242 |
243 | # Including strong name files can present a security risk
244 | # (https://github.com/github/gitignore/pull/2483#issue-259490424)
245 | #*.snk
246 |
247 | # Since there are multiple workflows, uncomment next line to ignore bower_components
248 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
249 | #bower_components/
250 |
251 | # RIA/Silverlight projects
252 | Generated_Code/
253 |
254 | # Backup & report files from converting an old project file
255 | # to a newer Visual Studio version. Backup files are not needed,
256 | # because we have git ;-)
257 | _UpgradeReport_Files/
258 | Backup*/
259 | UpgradeLog*.XML
260 | UpgradeLog*.htm
261 | ServiceFabricBackup/
262 | *.rptproj.bak
263 |
264 | # SQL Server files
265 | *.mdf
266 | *.ldf
267 | *.ndf
268 |
269 | # Business Intelligence projects
270 | *.rdl.data
271 | *.bim.layout
272 | *.bim_*.settings
273 | *.rptproj.rsuser
274 | *- [Bb]ackup.rdl
275 | *- [Bb]ackup ([0-9]).rdl
276 | *- [Bb]ackup ([0-9][0-9]).rdl
277 |
278 | # Microsoft Fakes
279 | FakesAssemblies/
280 |
281 | # GhostDoc plugin setting file
282 | *.GhostDoc.xml
283 |
284 | # Node.js Tools for Visual Studio
285 | .ntvs_analysis.dat
286 | node_modules/
287 |
288 | # Visual Studio 6 build log
289 | *.plg
290 |
291 | # Visual Studio 6 workspace options file
292 | *.opt
293 |
294 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
295 | *.vbw
296 |
297 | # Visual Studio 6 auto-generated project file (contains which files were open etc.)
298 | *.vbp
299 |
300 | # Visual Studio 6 workspace and project file (working project files containing files to include in project)
301 | *.dsw
302 | *.dsp
303 |
304 | # Visual Studio 6 technical files
305 | *.ncb
306 | *.aps
307 |
308 | # Visual Studio LightSwitch build output
309 | **/*.HTMLClient/GeneratedArtifacts
310 | **/*.DesktopClient/GeneratedArtifacts
311 | **/*.DesktopClient/ModelManifest.xml
312 | **/*.Server/GeneratedArtifacts
313 | **/*.Server/ModelManifest.xml
314 | _Pvt_Extensions
315 |
316 | # Paket dependency manager
317 | .paket/paket.exe
318 | paket-files/
319 |
320 | # FAKE - F# Make
321 | .fake/
322 |
323 | # CodeRush personal settings
324 | .cr/personal
325 |
326 | # Python Tools for Visual Studio (PTVS)
327 | __pycache__/
328 | *.pyc
329 |
330 | # Cake - Uncomment if you are using it
331 | # tools/**
332 | # !tools/packages.config
333 |
334 | # Tabs Studio
335 | *.tss
336 |
337 | # Telerik's JustMock configuration file
338 | *.jmconfig
339 |
340 | # BizTalk build output
341 | *.btp.cs
342 | *.btm.cs
343 | *.odx.cs
344 | *.xsd.cs
345 |
346 | # OpenCover UI analysis results
347 | OpenCover/
348 |
349 | # Azure Stream Analytics local run output
350 | ASALocalRun/
351 |
352 | # MSBuild Binary and Structured Log
353 | *.binlog
354 |
355 | # NVidia Nsight GPU debugger configuration file
356 | *.nvuser
357 |
358 | # MFractors (Xamarin productivity tool) working folder
359 | .mfractor/
360 |
361 | # Local History for Visual Studio
362 | .localhistory/
363 |
364 | # Visual Studio History (VSHistory) files
365 | .vshistory/
366 |
367 | # BeatPulse healthcheck temp database
368 | healthchecksdb
369 |
370 | # Backup folder for Package Reference Convert tool in Visual Studio 2017
371 | MigrationBackup/
372 |
373 | # Ionide (cross platform F# VS Code tools) working folder
374 | .ionide/
375 |
376 | # Fody - auto-generated XML schema
377 | FodyWeavers.xsd
378 |
379 | # VS Code files for those working on multiple tools
380 | .vscode/*
381 | !.vscode/settings.json
382 | !.vscode/tasks.json
383 | !.vscode/launch.json
384 | !.vscode/extensions.json
385 | *.code-workspace
386 |
387 | # Local History for Visual Studio Code
388 | .history/
389 |
390 | # Windows Installer files from build outputs
391 | *.cab
392 | *.msi
393 | *.msix
394 | *.msm
395 | *.msp
396 |
397 | # JetBrains Rider
398 | *.sln.iml
399 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Extract Structured Entities Using Language Models
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | [CodeQL](https://github.com/microsoft/Structured-Entity-Extraction/actions/workflows/github-code-scanning/codeql)
10 |
11 | [Dependabot Updates](https://github.com/microsoft/Structured-Entity-Extraction/actions/workflows/dependabot/dependabot-updates)
12 |
13 |
14 | **[🔥 Oral, top 7% of all accepted papers 🔥]**
15 |
16 | ⚙️ This is the implementation of ["**Learning to Extract Structured Entities Using Language Models**"](https://arxiv.org/pdf/2402.04437), a collaboration between **MSR** and **Mila**, accepted to the **EMNLP 2024 Main Conference**.
17 |
18 | ## Abstract
19 | Recent advances in machine learning have significantly impacted the field of information extraction, with Language Models (LMs) playing a pivotal role in extracting structured information from unstructured text. Prior works typically represent information extraction as triplet-centric and use classical metrics such as precision and recall for evaluation. We reformulate the task to be entity-centric, enabling the use of diverse metrics that can provide more insights from various perspectives. We contribute to the field by introducing Structured Entity Extraction and proposing the Approximate Entity Set OverlaP (AESOP) metric, designed to appropriately assess model performance. Later, we introduce a new Multi-stage Structured Entity Extraction (MuSEE) model that harnesses the power of LMs for enhanced effectiveness and efficiency by decomposing the extraction task into multiple stages. Quantitative and human side-by-side evaluations confirm that our model outperforms baselines, offering promising directions for future advancements in structured entity extraction.
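
The core of the AESOP metric is an optimal one-to-one (bipartite) matching between the predicted and target entity sets, computed before pairwise similarities are aggregated under a chosen normalization. The sketch below illustrates only that matching step; it assumes `scipy` and a user-supplied `pairwise_sim` function, both stand-ins for the actual implementation in `metric.py`:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def aesop_sketch(targets, predictions, pairwise_sim, normalization="Max"):
    """Toy AESOP-style score: optimal 1-1 matching, then normalization."""
    if not targets or not predictions:
        return 0.0
    # Pairwise similarity between every (target, prediction) pair.
    sim = np.array([[pairwise_sim(t, p) for p in predictions] for t in targets])
    # Hungarian algorithm; negate because linear_sum_assignment minimizes cost.
    rows, cols = linear_sum_assignment(-sim)
    denom = {
        "Max": max(len(targets), len(predictions)),
        "Precision": len(predictions),
        "Recall": len(targets),
    }[normalization]
    return sim[rows, cols].sum() / denom
```

The `Max` / `Precision` / `Recall` normalizations mirror the three variants computed in `compute_metrics.py`.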
20 |
21 |
22 |
23 |
24 | ## Install Dependencies
25 | ```bash
26 | conda create -n MuSEE python=3.8 && conda activate MuSEE
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ## Directory structure
31 | ```
32 | data/ # Dataset generation code will be released once it clears the internal review process.
33 | |-- GPT4-based/ # GPT4-based dataset
34 | |-- Wikidata-based/ # Wikidata-based dataset
35 | |-- nyt/ # New York Times Relation Extraction dataset
36 | |-- conll04/ # CoNLL04 dataset
37 | |-- REBEL/ # REBEL dataset
38 | |-- TREX/ # T-REx dataset
39 | |-- dataloader_musee.py # Dataloader for MuSEE model
40 | model/
41 | |-- t5_with_t5decoder.py # base model architecture for MuSEE
42 | trainer/
43 | |-- trainer_musee.py # Trainer for MuSEE model
44 | args.py # Arguments for MuSEE model and running experiments
45 | experiment_musee.py # Main file to run experiments
46 | metric.py # Calculate different variants of the proposed AESOP metric
47 | compute_metrics.py # Calculate metrics for the entire dataset
48 | requirements.txt # Required packages
49 | utils.py # Utility functions
50 | ```
51 |
52 | ## Run the code
53 | ```bash
54 | python experiment_musee.py \
55 | --model_choice=musee \
56 | --dataset=gpt4 \
57 | --pretrained_model_name=t5-large \
58 | --batch_size=1 \
59 | --epochs=100 \
60 | --log_wandb=True \
61 | --use_lora=True \
62 | --lr=1e-4 \
63 | --weight_decay=1e-2 \
64 | --mode=train \
65 | --loss_mode=mean \
66 | --use_better_init=True
67 | ```
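
After generation, the AESOP variants can be computed by pointing `evaluate` from `compute_metrics.py` at a ground-truth file and a matching prediction file. The paths below are placeholders:

```python
from compute_metrics import evaluate

# Both files map document IDs to {"entities": {...}} dicts (paths are examples).
results = evaluate("data/ground_truth.json", "generation_output/predictions.json")
print(results["multi_prop_max"], results["combined_coverage"])
```

`evaluate` also writes the full set of metrics next to the prediction file as `*_metrics.json`.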
68 |
69 | ## Citation and Contact
70 | If you find this paper useful, please cite our work:
71 | ```
72 | @inproceedings{wu2024structured,
73 |   title={Learning to Extract Structured Entities Using Language Models},
74 |   author={Wu, Haolun and Yuan, Ye and Mikaelyan, Liana and Meulemans, Alexander and Liu, Xue and Hensman, James and Mitra, Bhaskar},
75 | booktitle = "Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing",
76 | month = nov,
77 | year = "2024",
78 | address = "Miami, USA",
79 | publisher = "Association for Computational Linguistics",
80 | }
81 | ```
82 |
83 | 💬 If you have any questions, feel free to contact us through email (haolun.wu@mail.mcgill.ca, ye.yuan3@mail.mcgill.ca) or Github issues. Enjoy!
84 |
85 |
86 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser, ArgumentTypeError
3 | from datetime import datetime
4 |
5 |
6 | def str2bool(v):
7 | if isinstance(v, bool):
8 | return v
9 | if v.lower() in ("yes", "true", "t", "y", "1"):
10 | return True
11 | elif v.lower() in ("no", "false", "f", "n", "0"):
12 | return False
13 | else:
14 | raise ArgumentTypeError("Boolean value expected.")
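
# str2bool is wired in below as `type=str2bool, nargs="?"` so that boolean flags
# accept explicit values on the command line, e.g. --use_lora=True or --log_wandb false.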
15 |
16 |
17 | def parse_args(): # Parse command line arguments
18 | parser = ArgumentParser(description="mlm_seq")
19 | parser.add_argument(
20 |         "--use_data", default=5, type=int, help="Number of data examples to use"
21 | )
22 | parser.add_argument(
23 | "--batch_size", default=1, type=int, help="Batch size for training"
24 | )
25 | parser.add_argument(
26 | "--epochs", default=500, type=int, help="Number of epochs to train for"
27 | )
28 | parser.add_argument(
29 | "--lr", default=3e-4, type=float, help="Learning rate for the optimizer"
30 | )
31 | parser.add_argument(
32 | "--special_token_lr",
33 | default=1e-2,
34 | type=float,
35 |         help="Learning rate for the special-token embeddings",
36 | )
37 | parser.add_argument(
38 | "--use_special_token",
39 | type=str2bool,
40 | nargs="?",
41 | default=True,
42 |         help="True denotes using special tokens",
43 | )
44 | parser.add_argument(
45 | "--use_diff_lr",
46 | type=str,
47 | default="all_embed",
48 | choices=["special", "all_embed"],
49 | help="all_embed denotes using larger learning rate for all embedding layers",
50 | )
51 | parser.add_argument(
52 | "--use_better_init",
53 | type=str2bool,
54 | nargs="?",
55 | default=True,
56 | help="true denotes using better initialization",
57 | )
58 | parser.add_argument(
59 | "--noise_std_dev",
60 | default=1e-3,
61 | type=float,
62 | help="the variance for the random gaussian noise",
63 | )
64 | parser.add_argument(
65 | "--dataset",
66 | default="toy",
67 | type=str,
68 | choices=["internal", "gpt4", "wikidata", "wikidata-hallu", "toy"],
69 | help="Dataset name",
70 | )
71 | parser.add_argument(
72 | "--weight_decay",
73 | default=1e-2,
74 | type=float,
75 | help="Weight Decay for the optimizer",
76 | )
77 | parser.add_argument("--use_lora", type=str2bool, nargs="?", default=True)
78 | parser.add_argument("--alpha", default=0.5, type=float)
79 | parser.add_argument(
80 |         "--encoder_lr", default=1e-3, type=float, help="Learning rate for the encoder"
81 | )
82 | parser.add_argument(
83 | "--decoder_lr", default=1e-3, type=float, help="Learning rate for the optimizer"
84 | )
85 | parser.add_argument(
86 | "--num_fine_tune",
87 | default=100,
88 | type=int,
89 | help="Number of epochs to fine tune the pre-trained BERT",
90 | )
91 | parser.add_argument(
92 | "--ignore_zero",
93 | type=str2bool,
94 | nargs="?",
95 | default=True,
96 | help="True denotes ignoring the padding zeros when calculating the CE loss",
97 | )
98 | parser.add_argument(
99 | "--norm_loss",
100 | type=str2bool,
101 | nargs="?",
102 | default=True,
103 | help="Normalize the CE loss by the number of tokens in each property",
104 | )
105 | parser.add_argument(
106 | "--neg_sample",
107 | type=str2bool,
108 | nargs="?",
109 | default=True,
110 |         help="Use negative sampling when predicting property key existence",
111 | )
112 | parser.add_argument(
113 | "--use_pretrained",
114 | type=str2bool,
115 | nargs="?",
116 | default=True,
117 |         help="True denotes using a pretrained model from Hugging Face",
118 | )
119 | parser.add_argument(
120 | "--only_eval",
121 | type=str2bool,
122 | nargs="?",
123 | default=False,
124 | help="Only conduct evaluation",
125 | )
126 | parser.add_argument(
127 | "--log_wandb",
128 | type=str2bool,
129 | nargs="?",
130 | default=False,
131 | help="True denotes using wandb to log the training process",
132 | )
133 | parser.add_argument(
134 | "--pretrained_model_name",
135 | default="t5-small",
136 | type=str,
137 | # choices=["t5-small", "t5-base", "t5-large"],
138 | help="Which base model to use",
139 | )
140 | parser.add_argument(
141 | "--decoder_choice",
142 | default="T5",
143 | type=str,
144 | choices=["BERT", "GPT2", "T5"],
145 | help="Which decoder structure we will use as the decoder",
146 | )
147 | parser.add_argument(
148 | "--num_optimizer",
149 | default="1",
150 | type=str,
151 | choices=["1", "2", "3"],
152 |         help="Whether to use separate optimizers for the encoder and decoder",
153 | )
154 | parser.add_argument(
155 | "--seed",
156 | default=42,
157 | type=int,
158 |         help="Random seed for reproducibility",
159 | )
160 | parser.add_argument("--mode", choices=["train", "test"], default="train", type=str)
161 | parser.add_argument("--loss_mode", choices=["sum", "mean"], default="sum", type=str)
162 | parser.add_argument(
163 | "--check_epoch", default=6, type=int, help="check_epoch for evaluation"
164 | )
165 |     parser.add_argument("--top_entity", default=20, type=int, help="Keep only the top-k most frequent entity types")
166 |     parser.add_argument("--top_property", default=50, type=int, help="Keep only the top-k most frequent property keys")
167 | parser.add_argument(
168 | "--perturbation_test",
169 | default=True,
170 | type=str2bool,
171 | nargs="?",
172 | help="True denotes using a perturbed subset to test the model",
173 | )
174 | parser.add_argument(
175 | "--perturbation_exp",
176 | default=True,
177 | type=str2bool,
178 | nargs="?",
179 | help="True denotes using the perturbation testing set, "
180 | "and False denotes using the regular testing set",
181 | )
182 | parser.add_argument(
183 | "--training_mode",
184 | default=True,
185 | type=str2bool,
186 | nargs="?",
187 | help="False denotes loading saved model",
188 | )
189 | parser.add_argument(
190 |         "--decode_type", default="at", type=str, help="at (autoregressive) vs. nat (non-autoregressive) decoding"
191 | )
192 | parser.add_argument(
193 | "--model_choice",
194 | default="musee",
195 | choices=[
196 | "Single-mask-Multi-Entity-Step1",
197 | "Single-mask-Multi-Entity-Step2",
198 | "Single-mask-Multi-Entity-M3",
199 | "generative-llm",
200 | "M1",
201 | "E1",
202 | "musee",
203 | ],
204 | type=str,
205 | )
206 | parser.add_argument(
207 | "--generative_model",
208 | type=str,
209 | default="t5-small",
210 | choices=["gpt2", "gpt2-large", "t5-small", "t5-base", "t5-large", "llama_3B"],
211 |         help="Which LLM to use for the generative-llm model choice",
212 | )
213 | parser.add_argument(
214 | "--start_sentence",
215 | default="\n\nCreate a JSON file containing all named entities in the previous text:\n",
216 | type=str,
217 | )
218 | parser.add_argument("--total_batch_size", default=32, type=int)
219 | parser.add_argument("--gradient_accumulation_steps", default=32, type=int)
220 | parser.add_argument("--adam_beta1", default=0.9, type=float)
221 | parser.add_argument("--adam_beta2", default=0.999, type=float)
222 | parser.add_argument("--adam_epsilon", default=1e-8, type=float)
223 | parser.add_argument("--max_grad_norm", default=0.01, type=float)
224 | parser.add_argument("--lr_scheduler_type", default="linear", type=str)
225 | parser.add_argument("--num_warmup_steps", default=90, type=int)
226 | parser.add_argument("--no_cuda", type=str2bool, nargs="?", default=False)
227 | parser.add_argument("--generate_num_return_sequences", default=1, type=int)
228 | parser.add_argument("--generate_temperature", default=0.7, type=float)
229 | parser.add_argument("--generate_top_k", default=50, type=int)
230 | parser.add_argument("--generate_top_p", default=0.95, type=float)
231 | parser.add_argument("--generate_do_sample", type=str2bool, nargs="?", default=False)
232 | parser.add_argument("--generate_num_beams", default=1, type=int)
233 | parser.add_argument("--evaluation_strategy", default="epoch", type=str)
234 | parser.add_argument("--logging_steps", default=100, type=int)
235 | parser.add_argument("--save_final_model", default=True, type=str2bool, nargs="?")
236 | parser.add_argument("--save_strategy", default="epoch", type=str)
237 | parser.add_argument("--lora_alpha", default=32, type=int)
238 | parser.add_argument("--lora_dropout", default=0.1, type=float)
239 | parser.add_argument("--lora_r", default=16, type=int)
240 | parser.add_argument("--max_length", type=int, default=2048)
241 | parser.add_argument("--output_dir", type=str, default="logs/logs")
242 | parser.add_argument(
243 | "--load_best_model_at_end", type=str2bool, nargs="?", default=False
244 | )
245 | parser.add_argument(
246 | "--torch_dtype", default="float16", type=str, choices=["float16", "float32"]
247 | )
248 | parser.add_argument("--lora_target_modules", default=None)
249 | parser.add_argument("--lora_modules_to_save", default=None)
250 | parser.add_argument("--tight_padding", type=str2bool, nargs="?", default=True)
251 | parser.add_argument("--save_total_limit", type=int, default=2)
252 | parser.add_argument("--saved_model_path", type=str)
253 | parser.add_argument(
254 | "--eval_batch_size",
255 | default=8,
256 | type=int,
257 | help="Batch size for LLM evaluation and output generation",
258 | )
259 | parser.add_argument("--generate_output_path", type=str, default="generation_output")
260 | parser.add_argument("--st_checkpoint_dir", type=str, default="st_checkpoint")
261 |
262 | args = parser.parse_args()
263 | return postprocess_args(args)
264 |
265 |
266 | def postprocess_args(args):
267 | curr_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
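    # The run directory name concatenates the key hyperparameters; an
    # illustrative (hypothetical) example:
    #   logs/logs_model_musee_gpt4_lr_0.0001_wd_0.01_alpha_0.5_lora_True_loss_mode_mean_init_True_train/run_2024-11-01_12-00-00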
268 | output_dir = os.path.join(
269 | args.output_dir
270 | + f"_model_{args.model_choice}"
271 | + f"_{args.dataset}"
272 | + f"_lr_{str(args.lr)}"
273 | + f"_wd_{str(args.weight_decay)}"
274 | + f"_alpha_{str(args.alpha)}"
275 | + f"_lora_{str(args.use_lora)}"
276 | + f"_loss_mode_{str(args.loss_mode)}"
277 | + f"_init_{str(args.use_better_init)}"
278 | + f"_{str(args.mode)}",
279 | f"run_{curr_date}",
280 | )
281 | args.out_dir = output_dir
282 | if not os.path.exists(output_dir):
283 | os.makedirs(output_dir)
284 |
285 | return args
286 |
--------------------------------------------------------------------------------
/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import json
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from metric import compute_bipartite_matching_metrics
6 |
7 |
8 | def evaluate(ground_truth_path, prediction_path):
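    # Both arguments are paths to JSON files mapping document IDs to
    # {"entities": {<id>: {"type": ..., ...}, ...}} dicts. The returned dict
    # holds the AESOP variants plus per-property statistics, and is also
    # written next to the prediction file as "<prediction>_metrics.json".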
9 | target_count = 0
10 | target_type_counts = defaultdict(int)
11 | generated_count = 0
12 | generated_type_counts = defaultdict(int)
13 | target_entities_without_type = 0
14 | generated_entities_without_type = 0
15 |
16 | metrics = []
17 | counts = []
18 | nested_counts = []
19 |
20 | with open(ground_truth_path, "r") as f:
21 | ground_truth = json.load(f)
22 |
23 | with open(prediction_path, "r") as f:
24 | predictions = json.load(f)
25 |
26 | new_dict = {}
27 | for k, v in enumerate(ground_truth.values()):
28 | new_dict[str(k)] = v
29 | ground_truth = new_dict
30 |
31 | new_dict = {}
32 | for k, v in enumerate(predictions.values()):
33 | new_dict[str(k)] = v
34 | predictions = new_dict
35 |
36 | print("Length of ground truth:", len(ground_truth))
37 | print("Length of predictions:", len(predictions))
38 |
39 | # Ensure both files have matching document IDs
40 | assert set(ground_truth.keys()) == set(
41 | predictions.keys()
42 | ), "Mismatch in document IDs between ground truth and predictions."
43 |
44 | exact_name_max = []
45 | exact_name_precision = []
46 | exact_name_recall = []
47 | approx_name_max = []
48 | approx_name_precision = []
49 | approx_name_recall = []
50 | multi_prop_max = []
51 | multi_prop_precision = []
52 | multi_prop_recall = []
53 | for doc_id, doc_data in ground_truth.items():
54 | target_entities = doc_data["entities"]
55 | target_count += len(target_entities)
56 |
57 | for entity in target_entities.values():
58 | try:
59 | target_type_counts[entity["type"]] += 1
60 | except KeyError:
61 | target_entities_without_type += 1
62 |
63 | generated_output = predictions[doc_id]["entities"]
64 |
65 | generated_entities = list(generated_output.values())
66 | generated_count += len(generated_entities)
67 |
68 | for entity in generated_entities:
69 | if "type" in entity:
70 | generated_type_counts[entity["type"]] += 1
71 | else:
72 | generated_entities_without_type += 1
73 |
74 | for normalization in ["Max", "Precision", "Recall"]:
75 | for measures in ["ExactName", "ApproxName", "MultiProp"]:
76 | (
77 | final_metrics,
78 | count,
79 | nested_count,
80 | ) = compute_bipartite_matching_metrics(
81 | target_entities,
82 | generated_entities,
83 | measures=measures,
84 | normalization=normalization,
85 | establish_threshold=0.6,
86 | )
87 | if measures == "ExactName":
88 | if normalization == "Max":
89 | exact_name_max.append(final_metrics["normalized_similarity"])
90 | elif normalization == "Precision":
91 | exact_name_precision.append(
92 | final_metrics["normalized_similarity"]
93 | )
94 | elif normalization == "Recall":
95 | exact_name_recall.append(final_metrics["normalized_similarity"])
96 | elif measures == "ApproxName":
97 | if normalization == "Max":
98 | approx_name_max.append(final_metrics["normalized_similarity"])
99 | elif normalization == "Precision":
100 | approx_name_precision.append(
101 | final_metrics["normalized_similarity"]
102 | )
103 | elif normalization == "Recall":
104 | approx_name_recall.append(
105 | final_metrics["normalized_similarity"]
106 | )
107 | elif measures == "MultiProp":
108 | if normalization == "Max":
109 | multi_prop_max.append(final_metrics["normalized_similarity"])
110 | metrics.append(final_metrics)
111 | counts.append(count)
112 | nested_counts.append(nested_count)
113 | elif normalization == "Precision":
114 | multi_prop_precision.append(
115 | final_metrics["normalized_similarity"]
116 | )
117 | elif normalization == "Recall":
118 | multi_prop_recall.append(final_metrics["normalized_similarity"])
119 | # Compute and log metrics for generated text
120 | keys = set(key for d in metrics for key in d.keys())
121 | quantiles = [5, 10, 25, 50, 75, 90, 95]
122 |
123 | def compute_quantiles(data, quantiles):
124 | return {q: np.percentile(data, q) for q in quantiles}
125 |
126 | avg_metrics = {
127 | key: {
128 | "average": np.mean(
129 | [metric[key] for metric in metrics if key in metric.keys()]
130 | ),
131 | "quantiles": compute_quantiles(
132 | [metric[key] for metric in metrics if key in metric.keys()],
133 | quantiles,
134 | ),
135 | "raw_data": [metric[key] for metric in metrics if key in metric.keys()],
136 | }
137 | for key in keys
138 | }
139 | avg_metrics.update(
140 | {
141 | key
142 | + "_average": np.mean(
143 | [metric[key] for metric in metrics if key in metric.keys()]
144 | )
145 | for key in keys
146 | }
147 | )
148 | keys = set(key for d in counts for key in d.keys())
149 | total_metrics = {
150 | key: np.sum([count[key] for count in counts if key in count.keys()])
151 | for key in keys
152 | }
153 |
154 | outer_keys = set(key for d in nested_counts for key in d.keys())
155 | inner_keys = set(
156 | key
157 | for d in nested_counts
158 | for inner_dict in d.values()
159 | for key in inner_dict.keys()
160 | )
161 |
162 | total_nested_counts = {}
163 | for k in outer_keys:
164 | total_nested_counts[k] = {
165 | key: np.sum(
166 | [
167 | count_dict[k][key]
168 | for count_dict in nested_counts
169 | if key in count_dict[k].keys()
170 | ]
171 | )
172 | for key in inner_keys
173 | }
174 |
175 | property_metrics = {}
176 | for k in inner_keys:
177 | property_metrics[k] = {
178 | "acc_token": (
179 | total_nested_counts["per_property_acc_token"][k]
180 | / total_nested_counts["key_matches"][k]
181 | if total_nested_counts["key_matches"][k] > 0
182 | else 0
183 | ),
184 | "acc_aon": (
185 | total_nested_counts["per_property_acc_aon"][k]
186 | / total_nested_counts["key_matches"][k]
187 | if total_nested_counts["key_matches"][k] > 0
188 | else 0
189 | ),
190 | "key_coverage": (
191 | total_nested_counts["key_matches"][k]
192 | / total_nested_counts["target_key_occurance"][k]
193 | if total_nested_counts["target_key_occurance"][k] > 0
194 | else 0
195 | ),
196 | "key_precision": (
197 | total_nested_counts["key_matches"][k]
198 | / total_nested_counts["pred_key_occurance"][k]
199 | if total_nested_counts["pred_key_occurance"][k] > 0
200 | else 0
201 | ),
202 | }
203 | print("Target Count:", target_count)
204 | print("Generated Count:", generated_count)
205 | print("Target Entities without type:", target_entities_without_type)
206 | print("Generated Entities without type:", generated_entities_without_type)
207 |
208 | avg_metrics.update(
209 | {
210 | # "avg_target_entities": avg_target_entities,
211 | # "target_type_counts": target_type_counts,
212 | # "avg_generated_entities": avg_generated_entities,
213 | # "generated_type_counts": generated_type_counts,
214 | # "avg_target_entities_without_type": avg_target_entities_without_type,
215 | # "avg_generated_entities_without_type": avg_generated_entities_without_type,
216 | "combined_coverage": total_metrics["established_entity_matches"]
217 | / total_metrics["target_entities_no_dup"],
218 | "combined_precision": total_metrics["established_entity_matches"]
219 | / total_metrics["predicted_output_entities_no_dup"],
220 | }
221 | )
222 |
223 | avg_metrics.update(property_metrics)
224 |
225 | result = {}
226 | result["exact_name_max"] = np.mean(exact_name_max)
227 | result["exact_name_precision"] = np.mean(exact_name_precision)
228 | result["exact_name_recall"] = np.mean(exact_name_recall)
229 | result["approx_name_max"] = np.mean(approx_name_max)
230 | result["approx_name_precision"] = np.mean(approx_name_precision)
231 | result["approx_name_recall"] = np.mean(approx_name_recall)
232 | result["multi_prop_max"] = np.mean(multi_prop_max)
233 | result["multi_prop_precision"] = np.mean(multi_prop_precision)
234 | result["multi_prop_recall"] = np.mean(multi_prop_recall)
235 | result["target_count"] = target_count
236 | result["generated_count"] = generated_count
237 | result["target_type_counts"] = target_type_counts
238 | result["generated_type_counts"] = generated_type_counts
239 | result["target_entities_without_type"] = target_entities_without_type
240 | result["generated_entities_without_type"] = generated_entities_without_type
241 | result.update(avg_metrics)
242 | with open(prediction_path[:-5] + "_metrics.json", "w") as f:
243 | json.dump(result, f, indent=4)
244 | return result
245 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class DataManager(ABC):
5 | @abstractmethod
6 | def create_dataset(
7 | self, file_path: str, use_data: int = None, max_length: int = 512, **kwargs
8 | ):
9 | # return a pytorch dataset
10 | raise NotImplementedError
11 |
12 | @abstractmethod
13 | def create_data_loader(
14 | self,
15 | file_path: str,
16 | use_data: int,
17 | max_length: int,
18 | batch_size: int,
19 | shuffle: bool,
20 | **kwargs
21 | ):
22 | # return pytorch data loaders for training, validation, test
23 | raise NotImplementedError
24 |
--------------------------------------------------------------------------------
/data/dataloader_musee.py:
--------------------------------------------------------------------------------
1 | import json
2 | from copy import deepcopy
3 |
4 | import torch
5 | from torch.utils.data import DataLoader, Dataset, random_split
6 | from transformers import T5Tokenizer
7 |
8 | from . import DataManager
9 |
10 |
11 | class WikiDataStep1Manager(DataManager):
12 | class WikiDatasetStep1Filtered(Dataset):
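        """Dataset yielding, per document, the labels for MuSEE's three decoding
        stages: entity names (labels_ent_name), entity types + property keys
        (labels_pk), and property values (labels_pv), with matching attention masks."""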
13 | def __init__(
14 | self,
15 | data,
16 | tokenizer,
17 | max_length,
18 | top_entity=None,
19 | top_property=None,
20 | data_name=None,
21 | ):
22 | self.data = data
23 | self.tokenizer = tokenizer
24 | self.max_length = max_length
25 | # self.max_num_entity = max(
26 | # len(item.get("entities", {})) for item in data.values()
27 | # )
28 |             if data_name in ["d3", "d3-hallu"]:
29 |                 self.max_num_entity = 20
30 |             elif data_name in ["d2", "d2-hallu"]:
31 |                 self.max_num_entity = 10
32 |             elif data_name == "toy":
33 |                 self.max_num_entity = 4
            else:
                # Fallback for other data_name values, mirroring the
                # commented-out computation above.
                self.max_num_entity = max(
                    len(item.get("entities", {})) for item in data.values()
                )
34 | self.max_entities = max(len(item["entities"]) for item in data.values())
35 | self.max_prop = max(
36 | len(item["entities"][entity])
37 | for item in data.values()
38 | for entity in item["entities"]
39 | )
40 | self.max_prop_length = 15
41 |
42 | # Initialize storage variables
43 | self.all_entity_types = []
44 | self.all_pks = []
45 | self.entity_type_counts = {}
46 | self.property_key_counts = {}
47 |
48 | self._compute_counts()
49 |
50 | # Filter by top entity and property
51 | if top_entity:
52 | self._filter_top_entities(top_entity)
53 | if top_property:
54 | self._filter_top_properties(top_property)
55 |
56 | self._filter_data()
57 |
58 | self.num_entity_types = len(self.all_entity_types)
59 | self.num_all_pks = len(self.all_pks)
60 |
61 | self._build_template()
62 |
63 | def _compute_counts(self):
64 | """Compute counts for entity types and property keys."""
65 | for item in self.data.values():
66 | for entity in item["entities"].values():
67 | # Exclude the "pk_type" key
68 | self.all_pks.extend([k for k in entity.keys() if k != "pk_type"])
69 | entity_type = entity["pk_type"]
70 | self.all_entity_types.append(entity_type)
71 |
72 | self.entity_type_counts[entity_type] = (
73 | self.entity_type_counts.get(entity_type, 0) + 1
74 | )
75 | for pk in entity.keys():
76 | if pk != "pk_type":
77 | self.property_key_counts[pk] = (
78 | self.property_key_counts.get(pk, 0) + 1
79 | )
80 |
81 | # Sorting entity_type_counts from high to low
82 | self.entity_type_counts = dict(
83 | sorted(
84 | self.entity_type_counts.items(),
85 | key=lambda item: item[1],
86 | reverse=True,
87 | )
88 | )
89 |
90 | # Sorting property_key_counts from high to low
91 | self.property_key_counts = dict(
92 | sorted(
93 | self.property_key_counts.items(),
94 | key=lambda item: item[1],
95 | reverse=True,
96 | )
97 | )
98 |
99 | # Sorting all_entity_types based on entity_type_counts order
100 | self.all_entity_types = sorted(
101 | self.all_entity_types,
102 | key=lambda x: self.entity_type_counts[x],
103 | reverse=True,
104 | )
105 |
106 | # Sorting all_pks based on property_key_counts order
107 | self.all_pks = sorted(
108 | self.all_pks, key=lambda x: self.property_key_counts[x], reverse=True
109 | )
110 |
111 | def _filter_top_entities(self, top_entity):
112 | """Filter top entities."""
113 | sorted_entities = sorted(
114 | self.entity_type_counts.items(), key=lambda x: x[1], reverse=True
115 | )
116 | self.all_entity_types = [e[0] for e in sorted_entities[:top_entity]]
117 | print(f"top_sorted_entities counts: {dict(sorted_entities[:top_entity])}\n")
118 |
119 | def _filter_top_properties(self, top_property):
120 | """Filter top properties."""
121 | sorted_properties = sorted(
122 | self.property_key_counts.items(), key=lambda x: x[1], reverse=True
123 | )
124 | self.all_pks = [p[0] for p in sorted_properties[:top_property]]
125 | print(
126 | f"top_sorted_properties counts: {dict(sorted_properties[:top_property])}\n"
127 | )
128 |
129 | def _filter_data(self):
130 | """Filter data items and entities."""
131 | # Filter out items with no entities of desired types
132 | keys_to_remove = [
133 | key
134 | for key, item in self.data.items()
135 | if not any(
136 | e["pk_type"] in self.all_entity_types
137 | for e in item["entities"].values()
138 | )
139 | ]
140 | for key in keys_to_remove:
141 | del self.data[key]
142 |
143 | # Filter out entities not having any of the selected top properties
144 | for item in self.data.values():
145 | entities_to_remove = [
146 | entity_key
147 | for entity_key, entity in item["entities"].items()
148 | if not (set(entity.keys()) - {"pk_type"}) & set(self.all_pks)
149 | ]
150 | for entity_key in entities_to_remove:
151 | del item["entities"][entity_key]
152 |
153 | # Remove items with no entities or only with "0" type
154 | keys_to_remove = [
155 | key
156 | for key, item in self.data.items()
157 | if not item.get("entities")
158 | or all(
159 | entity["pk_type"] == 0 for entity in item.get("entities").values()
160 | )
161 | ]
162 | for key in keys_to_remove:
163 | del self.data[key]
164 |
165 | def _build_template(self):
166 | """Post filtering setup to determine unique entity types and property keys."""
167 | # self.all_entity_types = sorted(list(set(self.all_entity_types)))
168 | # self.all_pks = sorted(list(set(self.all_pks)))
169 |
170 | # Create a template tensor to store property presence info
171 | self.template = torch.zeros(
172 | (self.num_entity_types, self.num_all_pks), dtype=torch.int
173 | )
174 | for item in self.data.values():
175 | for entity_name in item["entities"]:
176 | entity = item["entities"][entity_name]
177 | if entity["pk_type"] in self.all_entity_types:
178 | for pk in entity:
179 | if pk in self.all_pks:
180 | self.template[
181 | self.all_entity_types.index(entity["pk_type"])
182 | ][self.all_pks.index(pk)] = 1
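            # template[i][j] == 1 iff entity type i co-occurs with property key j
            # anywhere in the data, giving a per-type mask over admissible keys
            # (downstream usage assumed).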
183 | print("template sum by E_type:", self.template.sum(1))
184 | print("template sum by Pk:", self.template.sum(0))
185 |
186 | def __len__(self):
187 | return len(self.data)
188 |
189 | def id2entity(self):
190 | # index 0 is for padded (not-real) entity
191 | return {i + 1: entity for i, entity in enumerate(self.all_entity_types)}
192 |
193 | def id2property(self):
194 | return {i: property_key for i, property_key in enumerate(self.all_pks)}
195 |
196 | def get_entity_label(self):
197 | entity_label = {}
198 | for i, e in enumerate(self.all_entity_types):
199 | entity_label[e] = i + 1
200 | entity_label["padding"] = 0
201 | return entity_label
202 |
203 | def get_all_template(self):
204 | return self.template
205 |
206 | def get_template(self, entity_type: str):
207 | return self.template[self.all_entity_types.index(entity_type)]
208 |
209 | def sort_entity_values(self, entity_values):
210 | # Function to get ID for a key
211 | def get_id(key):
212 | return self.tokenizer.convert_tokens_to_ids(key)
213 |
214 | # Excluding 'pk_type' and sorting the remaining keys based on their IDs
215 | sorted_keys = sorted(
216 | [key for key in entity_values if key != "pk_type"], key=get_id
217 | )
218 |
219 | # Constructing the sorted dictionary
220 | sorted_dict = {"pk_type": entity_values["pk_type"]}
221 | for key in sorted_keys:
222 | sorted_dict[key] = entity_values[key]
223 |
224 | return sorted_dict
225 |
226 | def __getitem__(self, idx):
227 | item = self.data[list(self.data.keys())[idx]]
228 | description = item.get("description", None)
229 | entities = item.get("entities", None)
230 |
231 | context = item["description"].replace("\n", "")
232 |
233 | context = context.strip()
234 | context = context.encode("ascii", "ignore")
235 | context = context.decode()
236 | while context.find(" ") != -1:
237 | context = context.replace(" ", " ")
238 | context = context.replace(" ,", ",")
239 | context = context.replace(" .", ".")
240 | context = context.replace(" ?", "?")
241 | context = context.replace(" !", "!")
242 | context = context.replace(" :", ":")
243 | context = context.replace(" ;", ";")
244 | context = context.replace("^", "")
245 | context = context.replace("<", "")
246 | context = context.replace(">", "")
247 | context = context.replace(" !", "!")
248 | context = context.replace("~", "")
249 | context = context.replace("\\", "")
250 | context = context.replace("\t", " ")
251 | context = context.replace("{", "")
252 | context = context.replace("}", "")
253 | context = context.strip()
254 |
255 | """
256 | ### Obtain input_ids and attention mask
257 | """
258 |             src_text = context  # feed the cleaned description to the tokenizer
259 | src_tokenized = self.tokenizer.encode_plus(
260 | src_text,
261 | max_length=self.max_length,
262 | padding="max_length",
263 | return_attention_mask=True,
264 | return_tensors="pt",
265 | truncation=True,
266 | )
267 | input_ids = src_tokenized["input_ids"].flatten() # (seq_len, )
268 | attention_mask = src_tokenized["attention_mask"].flatten() # (seq_len, )
269 |
270 |
271 | # Extracting entity names and formatting them with [SEP] token
272 | entity_names = [entity["pk_entity_name"] for entity in entities.values()]
273 | entity_names = entity_names[
274 | : self.max_num_entity
275 | ] # only process max_num_entity
276 |             real_labels_ent_name = (  # "<sep>" below is an assumed separator marker (cf. the [SEP] comment above)
277 |                 "<sep> " + " <sep> ".join(entity_names) + " <sep>"
278 |             )
279 | labels_ent_name = self.tokenizer(
280 | real_labels_ent_name,
281 | max_length=self.max_num_entity * 6,
282 | padding="max_length",
283 | truncation=True,
284 | add_special_tokens=True,
285 | return_tensors="pt",
286 | )["input_ids"].squeeze(0)
287 |
288 | # print("real_labels_ent_name:", real_labels_ent_name)
289 | # print("labels_ent_name:", labels_ent_name)
290 |
291 | # Create the attention mask for the entity positions
292 | # attention_mask_ent = (labels_ent != -1).long()
293 | # attention_mask_ent_tokenized = (labels_ent_tokenized != 0).long()
294 | attention_mask_ent_name = (labels_ent_name != 0).long()
295 |
296 | """
297 | ### Labels for Step2: Obtain labels for Entity type and property key
298 | """
299 | end_token_id = self.tokenizer.eos_token_id
300 | labels_pk = torch.full(
301 | (self.max_num_entity, self.num_all_pks + 2), self.tokenizer.pad_token_id
302 | ) # (max_num_Entity, num_all_pks+2)
303 |
304 | for i, (_, entity_values) in enumerate(entities.items()):
305 | entity_values = self.sort_entity_values(entity_values)
306 | if (
307 | i >= self.max_num_entity
308 | ): # We only process up to max_num_entity entities
309 | break
310 |
311 | # Get the special token ID for the entity type
312 | ent_type_id = self.tokenizer.convert_tokens_to_ids(
313 | entity_values["pk_type"]
314 | )
315 | labels_pk[i, 0] = ent_type_id
316 |
317 | # Use a counter for the position in the label_ids tensor
318 | pk_counter = 1
319 |
320 | # Iterate over properties
321 | for pk in entity_values:
322 | if pk not in [
323 | "pk_type",
324 | "pk_entity_name",
325 | ]: # type is already added. No need to predict ent_name
326 | # Use the property key as a special token to get its ID
327 | pk_id = self.tokenizer.convert_tokens_to_ids(pk)
328 | if (
329 | pk_counter < labels_pk.size(1) - 1
330 | ): # Leave space for the end token
331 | labels_pk[i, pk_counter] = pk_id
332 | pk_counter += 1
333 |
334 | # Add the end token if there's at least one property key
335 | if pk_counter < labels_pk.size(1):
336 | labels_pk[i, pk_counter] = end_token_id
337 |
338 | # The attention mask is binary: 1 for special tokens, 0 for padding tokens
339 | attention_mask_pk = (
340 | labels_pk != self.tokenizer.pad_token_id
341 | ).long() # (max_num_Entity, num_all_pks+2)
342 |
343 | """
344 | ### Labels for Step3: Property values (max_num_Entity, num_all_pks, max_prop_len)
345 | """
346 | labels_pv = torch.full(
347 | (self.max_num_entity, self.num_all_pks, self.max_prop_length),
348 | self.tokenizer.pad_token_id,
349 | ) # Initialize with padding token ID
350 |
351 | for i, (_, entity_values) in enumerate(entities.items()):
352 | entity_values = self.sort_entity_values(entity_values)
353 | if (
354 | i >= self.max_num_entity
355 | ): # Only process up to max_num_entity entities
356 | break
357 |
358 | # Counter for the property key
359 | pk_counter = 0
360 | for pk in entity_values:
361 | if pk not in [
362 | "pk_type",
363 | "pk_entity_name",
364 | ]: # type is already added. No need to predict ent_name
365 | # Encode the property value
366 | encoded_prop = self.tokenizer.encode(
367 | entity_values[pk]
368 | + self.tokenizer.eos_token, # Encode the property value
369 | add_special_tokens=False,
370 | max_length=self.max_prop_length,
371 | padding="max_length",
372 | truncation=True,
373 | return_tensors="pt",
374 | ).flatten()
375 | # Place the encoded property value in the tensor
376 | labels_pv[i, pk_counter, :] = encoded_prop
377 | pk_counter += 1
378 |
379 | # Create a mask for the encoded properties
380 | attention_mask_pv = (labels_pv != self.tokenizer.pad_token_id).long()
381 |
382 | return {
383 | "input_ids": input_ids, # (seq_len, )
384 | # "labels_ent": labels_ent, # (max_num_Entity * 2, )
385 | # "labels_ent_tokenized": labels_ent_tokenized, # (max_num_Entity * 3, )
386 | "labels_ent_name": labels_ent_name, # (max_num_Entity * 6, )
387 | "real_labels_ent_name": real_labels_ent_name, # (max_num_Entity * 6, )
388 | "labels_pk": labels_pk, # (max_num_Entity, num_all_pks+2)
389 | "labels_pv": labels_pv, # (max_num_Entity, num_all_pks, max_prop_len)
390 | "attention_mask": attention_mask,
391 | # "attention_mask_ent": attention_mask_ent,
392 | # "attention_mask_ent_tokenized": attention_mask_ent_tokenized,
393 | "attention_mask_ent_name": attention_mask_ent_name,
394 | "attention_mask_pk": attention_mask_pk,
395 | "attention_mask_pv": attention_mask_pv,
396 | }
397 |
398 | def create_dataset(
399 | self, file_path: str, use_data: int = None, max_length: int = 1024, **kwargs
400 | ):
401 | pretrained_model_name = kwargs.get("model_name", None)
402 | tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name)
403 | special_tokens_need_to_add = []
404 | ent_type_tokens = [] # List for tokens starting with "ent_type_"
405 | pk_tokens = [] # List for tokens starting with "pk_"
406 |
407 | with open(file_path, "r") as f:
408 | data = json.load(f)
409 |
410 | # Modify the keys in the data
411 | for k, entities in data.items():
412 | for entity_id, entity in entities["entities"].items():
413 | new_entity = {}
414 | for prop_key, prop_value in entity.items():
415 | if prop_key == "type":
416 | new_key = f"pk_{prop_key}"
417 | new_value = f'ent_type_{prop_value.replace(" ", "_")}'
418 | new_entity[new_key] = new_value
419 | special_tokens_need_to_add.append(new_value)
420 | ent_type_tokens.append(new_value) # Add to ent_type_tokens
421 | # elif prop_key != "entity name": # we do not need to train / predict entity name for pk and pv
422 | else:
423 | new_key = f'pk_{prop_key.replace(" ", "_")}'
424 | new_entity[new_key] = prop_value
425 | special_tokens_need_to_add.append(new_key)
426 | pk_tokens.append(new_key) # Add to pk_tokens
427 | entities["entities"][entity_id] = new_entity
428 |
429 | special_tokens_need_to_add = sorted(list(set(special_tokens_need_to_add)))
430 | ent_type_tokens = sorted(
431 | list(set(ent_type_tokens))
432 | ) # Remove duplicates and sort
433 | pk_tokens = sorted(list(set(pk_tokens))) # Remove duplicates and sort
434 | # print("special_tokens_need_to_add:", len(special_tokens_need_to_add), special_tokens_need_to_add)
435 | tokenizer.add_tokens(special_tokens_need_to_add)
436 |
437 | print("Full data length:", len(data))
438 | data = dict(list(data.items())[:use_data])
439 | print("Used data length:", len(data))
440 |
441 | top_entity = kwargs.get("top_entity", None)
442 | top_property = kwargs.get("top_property", None)
443 | data_name = kwargs.get("data_name", None)
444 | dataset = self.WikiDatasetStep1Filtered(
445 | data,
446 | tokenizer,
447 | max_length=max_length,
448 | top_entity=top_entity,
449 | top_property=top_property,
450 | data_name=data_name,
451 | )
452 | print("Max_num_entity:", dataset.max_num_entity)
453 | print("Max_num_pk:", dataset.num_all_pks)
454 | print("Max_prop_len:", dataset.max_prop_length)
455 | print("*************")
456 |
457 | return (
458 | dataset,
459 | tokenizer,
460 | special_tokens_need_to_add,
461 | ent_type_tokens,
462 | pk_tokens,
463 | )
464 |
465 | def create_data_loader(
466 | self,
467 | file_path: str,
468 | use_data: int,
469 | max_length: int,
470 | batch_size: int,
471 | shuffle: bool,
472 | **kwargs,
473 | ):
474 | dataset, tokenizer, _, _, _ = self.create_dataset(
475 | file_path, use_data, max_length, **kwargs
476 | )
477 |
478 | if len(dataset) < 10:
479 | train_size = len(dataset) - 2
480 | val_size = 1
481 | test_size = 1
482 | else:
483 | train_size = int(0.8 * len(dataset))
484 | val_size = int(0.1 * len(dataset))
485 | test_size = len(dataset) - train_size - val_size
486 |
487 | train_dataset, val_dataset, test_dataset = random_split(
488 | dataset, [train_size, val_size, test_size]
489 | )
490 |
491 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
492 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle)
493 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)
494 | print("train_dataloader:", len(train_loader))
495 | print("val_dataloader:", len(val_loader))
496 | print("test_dataloader:", len(test_loader))
497 |
498 | return train_loader, val_loader, test_loader, dataset, tokenizer
499 |
--------------------------------------------------------------------------------
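A quick sanity sketch of the split arithmetic in create_data_loader above (standard library only; the helper name split_sizes is illustrative): datasets with at least 10 examples get an 80/10/10 train/val/test split, while smaller ones fall back to (n-2)/1/1.

def split_sizes(n: int):
    # mirrors create_data_loader: 80/10/10 for n >= 10, (n-2)/1/1 otherwise
    if n < 10:
        return n - 2, 1, 1
    train = int(0.8 * n)
    val = int(0.1 * n)
    return train, val, n - train - val  # remainder goes to test

assert split_sizes(100) == (80, 10, 10)
assert split_sizes(7) == (5, 1, 1)
--------------------------------------------------------------------------------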
/experiment_musee.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 | import wandb
5 | from args import parse_args
6 | from compute_metrics import evaluate  # assumed module: the repo tree has compute_metrics.py, not metrics.py
7 | from peft import LoraConfig, PeftModel, get_peft_model
8 | from syne_tune import Reporter
9 | from torch.utils.data import DataLoader
10 | from transformers import get_linear_schedule_with_warmup
11 | from utils import get_attention_paths, print_trainable_parameters, set_seed
12 |
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 |
15 |
16 | def generate_final_json(results, ground_truth_path):
17 | # Load the ground truth JSON file
18 | with open(ground_truth_path, "r") as file:
19 | ground_truth_data = json.load(file)
20 |
21 | # Function to process each prediction and return formatted entities
22 | def process_prediction(ent_tokens, pk_tokens, pv_tokens):
23 | print("ent_tokens, pk_tokens, pv_tokens:", ent_tokens, pk_tokens, pv_tokens)
24 | entities = {}
25 | for i, (ent_name, pk_token, pv_token) in enumerate(
26 | zip(ent_tokens, pk_tokens, pv_tokens)
27 | ):
28 | pk_parts = pk_token.split()
29 |
30 | # Skip processing if pk_parts is empty
31 | if len(pk_parts) == 0:
32 | continue
33 |
34 | # Extracting entity type
35 | entity_type = (
36 | pk_parts[0].replace("ent_type_", "").replace("_", " ")
37 | if "ent_type_" in pk_parts[0]
38 | else "unknown"
39 | )
40 |
41 | entity_info = {"type": entity_type}
42 | entity_info["entity name"] = ent_name # No need to predict entity name
43 | for j, key in enumerate(
44 | pk_parts[1:]
45 | ): # Skip the first token, which is the type
46 | prop_key = key.replace("pk_", "").replace("_", " ")
47 | if "ent type" in prop_key:
48 | continue
49 | entity_info[prop_key] = pv_token[j] if j < len(pv_token) else ""
50 | # if entity_type == "human":
51 | # entity_info["given name"] = ent_name.split()[0]
52 | # entity_info["family name"] = ent_name.split()[-1]
53 |
54 | entities[str(i)] = entity_info
55 | return entities
56 |
57 | # Create the final JSON object
58 | final_json = {}
59 | for doc_id, prediction in zip(ground_truth_data, results):
60 | ent_tokens = prediction.get("predict_ent", [])
61 | pk_tokens = prediction.get("predict_pk", [])
62 | pv_tokens = prediction.get("predict_pv", [])
63 |
64 | entities = process_prediction(ent_tokens, pk_tokens, pv_tokens)
65 | final_json[doc_id] = {
66 | "doc_id": doc_id,
67 | "description": ground_truth_data[doc_id]["description"],
68 | "entities": entities,
69 | }
70 |
71 | return final_json
72 |
73 |
74 | def experiment(args):
75 | # forced_decoder_ids does not support batching, so batch_size must be 1 for inference
76 | if args.mode == "test":
77 | args.batch_size = 1
78 |
79 | if args.model_choice == "MuSEE":
80 | from trainer.trainer_musee import Trainer_E_Pk_Pv
81 |
82 | trainer = Trainer_E_Pk_Pv()
83 | if args.log_wandb:
84 | wandb.login()
85 | from data.dataloader_musee import WikiDataStep1Manager
86 |
87 | manager = WikiDataStep1Manager()
88 |
89 | if args.dataset == "toy":
90 | data_abbrev = "toy"
91 | train_data_path = "data/toy/D3_toy.json"
92 | val_data_path = "data/toy/D3_toy.json"
93 | test_data_path = "data/toy/D3_toy.json"
94 | # test_data_path = "data/toy/dummy.json"
95 | use_data = 100
96 | max_length = 512
97 | elif args.dataset == "gpt4":
98 | data_abbrev = "d2"
99 | train_data_path = "data/D2_final/D2_train_final.json"
100 | val_data_path = "data/D2_final/D2_val_final.json"
101 | test_data_path = "data/D2_final/D2_test_final.json"
102 | use_data = 20000
103 | max_length = 512
104 | elif args.dataset == "wikidata":
105 | data_abbrev = "d3"
106 | train_data_path = "data/D3_final/D3_train_final.json"
107 | val_data_path = "data/D3_final/D3_val_final.json"
108 | test_data_path = "data/D3_final/D3_test_final.json"
109 | use_data = 20000
110 | max_length = 512
111 | elif args.dataset == "wikidata-hallu":
112 | data_abbrev = "d3"
113 | train_data_path = "data/D3_final/D3_train_final.json"
114 | val_data_path = "data/D3_final/D3_val_final.json"
115 | test_data_path = "data/D3_final/D3_test_final_hallu_5k.json"
116 | use_data = 20000
117 | max_length = 512
118 | elif args.dataset == "internal":
119 | data_abbrev = "inter"
120 | raise NotImplementedError
121 |
122 | # Get dataloader
123 | dataset, tokenizer, special_tokens_need_to_add, ent_type_tokens, pk_tokens = (
124 | manager.create_dataset(
125 | file_path=train_data_path,
126 | model_name=args.pretrained_model_name,
127 | use_data=use_data,
128 | max_length=max_length,
129 | batch_size=args.batch_size,
130 | shuffle=False,
131 | if_filter=True,
132 | top_entity=args.top_entity,
133 | top_property=args.top_property,
134 | data_name=data_abbrev,
135 | )
136 | )
137 | val_dataset, _, _, _, _ = manager.create_dataset(
138 | file_path=val_data_path,
139 | model_name=args.pretrained_model_name,
140 | use_data=use_data,
141 | max_length=max_length,
142 | batch_size=args.batch_size,
143 | shuffle=False,
144 | if_filter=True,
145 | top_entity=args.top_entity,
146 | top_property=args.top_property,
147 | data_name=data_abbrev,
148 | )
149 | test_dataset, _, _, _, _ = manager.create_dataset(
150 | file_path=test_data_path,
151 | model_name=args.pretrained_model_name,
152 | use_data=use_data,
153 | max_length=max_length,
154 | batch_size=args.batch_size,
155 | shuffle=False,
156 | if_filter=True,
157 | top_entity=args.top_entity,
158 | top_property=args.top_property,
159 | data_name=data_abbrev,
160 | )
161 | # Get the indices of the new tokens
162 | added_new_token_ids = tokenizer.convert_tokens_to_ids(
163 | special_tokens_need_to_add
164 | )
165 | added_ent_type_tokens = tokenizer.convert_tokens_to_ids(ent_type_tokens)
166 | added_pk_tokens = tokenizer.convert_tokens_to_ids(pk_tokens)
167 | print("added_new_token_ids:", added_new_token_ids)
168 | print("added_ent_type_tokens:", added_ent_type_tokens)
169 | print("added_pk_tokens:", added_pk_tokens)
170 |
171 | train_dataloader = DataLoader(
172 | dataset, batch_size=args.batch_size, shuffle=False
173 | )
174 | val_dataloader = DataLoader(
175 | val_dataset, batch_size=args.batch_size, shuffle=False
176 | )
177 | test_dataloader = DataLoader(
178 | test_dataset, batch_size=args.batch_size, shuffle=False
179 | )
180 |
181 | max_seq_length = dataset.max_length
182 | num_entity_types = dataset.num_entity_types
183 | max_num_entity = dataset.max_num_entity
184 | num_property_keys = dataset.num_all_pks
185 | all_entity_types = dataset.all_entity_types
186 | entity_type_counts = dataset.entity_type_counts
187 | entity_type_counts["0"] = use_data * max_num_entity - sum(
188 | entity_type_counts.values()
189 | )
190 | entity_type_counts = {
191 | k: v
192 | for k, v in sorted(
193 | entity_type_counts.items(), key=lambda item: item[1], reverse=True
194 | )
195 | }
196 | property_key_counts = dataset.property_key_counts
197 | print("-----------")
198 | print("max_seq_length:", max_seq_length)
199 | print("num_entity_types:", num_entity_types)
200 | print("max_num_entity:", max_num_entity)
201 | print("num_property_keys:", num_property_keys)
202 | print("entity_type_counts:", len(entity_type_counts), entity_type_counts)
203 | print("property_key_counts:", property_key_counts)
204 |
205 | # type_weights = compute_inverse_frequency_weights(entity_type_counts, num_entity_types).to(device)
206 | # print("type_weights:", type_weights)
207 |
208 | # original_template = dataset.get_all_template().numpy()
209 | # all_zero_row = np.zeros(
210 | # original_template.shape[1], dtype=original_template.dtype
211 | # )
212 | # template = np.vstack(
213 | # (all_zero_row, original_template)
214 | # ) # add all-zero row for type 0
215 | # template = torch.tensor(template).to(device)
216 | # print("template:", template.shape)
217 | from trainer.trainer_musee import Predictor_E_Pk_Pv
218 |
219 | model = Predictor_E_Pk_Pv(
220 | pretrained_model_name=args.pretrained_model_name,
221 | max_seq_length=max_seq_length,
222 | max_num_entity=max_num_entity,
223 | tokenizer=tokenizer,
224 | ).to(device)
225 | model.t5_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
226 | print_trainable_parameters(model)
227 |
228 | mask_token, sep_token = "<mask>", "<sep>"  # assumed token strings; encoding "" below would raise an IndexError
229 | mask_token_id = torch.tensor(
230 | tokenizer.encode(mask_token, add_special_tokens=False)[0]
231 | ).item()
232 | sep_token_id = torch.tensor(
233 | tokenizer.encode(sep_token, add_special_tokens=False)[0]
234 | ).item()
235 |
236 | vocab_size = model.t5_model.get_input_embeddings().weight.size(0)
237 |
238 | print("vocab_size:", vocab_size)
239 | print("mask_token_id:", mask_token_id)
240 | print("sep_token_id:", sep_token_id)
241 | print("--------------------")
242 |
243 | # Set up wandb
244 | if args.log_wandb:
245 | run_name = "lr{}-wd{}-{}".format(args.lr, args.weight_decay, args.loss_mode)
246 | wandb.init(
247 | project="MuSEE-full-{}-{}-{}-lora-{}-init-{}".format(
248 | data_abbrev,
249 | args.pretrained_model_name,
250 | args.use_lora,
251 | args.loss_mode,
252 | args.use_better_init,
253 | ),
254 | config=args,
255 | name=run_name, # set the run name here
256 | )
257 |
258 | save_path = (
259 | f"saved/best_model/MuSEE/MuSEE_{data_abbrev}_m{args.pretrained_model_name}_"
260 | f"lr{args.lr}_wd{args.weight_decay}_{args.loss_mode}_lora_{args.use_lora}_init_{args.use_better_init}"
261 | f"best_model"
262 | )
263 |
264 | if args.use_better_init:
265 | print("Better initialize the special tokens' embeddings")
266 | print(
267 | "special_tokens_need_to_add:",
268 | len(special_tokens_need_to_add),
269 | special_tokens_need_to_add,
270 | )
271 | # Get the embeddings layer from the model
272 | embedding_layer = model.t5_model.get_input_embeddings()
273 | print(
274 | "old:",
275 | model.t5_model.get_input_embeddings().weight.shape,
276 | model.t5_model.get_input_embeddings().weight.sum(),
277 | )
278 |
279 | # Calculate new embeddings
280 | new_token_embeddings = []
281 | for token in special_tokens_need_to_add:
282 | # Tokenize the special token into subwords
283 | token = token.replace("ent_type", "")
284 | token = token.replace("pk", "")
285 | token = token.replace("_", " ")
286 | subtokens = tokenizer.tokenize(token)
287 |
288 | # Get the embeddings for the subtokens
289 | subtoken_ids = tokenizer.convert_tokens_to_ids(subtokens)
290 | subtoken_embeddings = embedding_layer.weight[subtoken_ids]
291 |
292 | # Calculate the average embedding
293 | average_embedding = subtoken_embeddings.mean(dim=0)
294 |
295 | # Add Gaussian noise to the average embedding
296 | noise = torch.randn(average_embedding.size()) * args.noise_std_dev
297 | new_embedding = average_embedding + noise.to(device)
298 |
299 | # Append to the list of new token embeddings
300 | new_token_embeddings.append(new_embedding)
301 |
302 | # Convert the list to a tensor
303 | new_token_embeddings = torch.stack(new_token_embeddings)
304 |
305 | # Set the embeddings for the new tokens in the model
306 | with torch.no_grad():
307 | # Get the indices of the new tokens
308 | new_token_ids = tokenizer.convert_tokens_to_ids(
309 | special_tokens_need_to_add
310 | )
311 | # Update the embeddings for these tokens
312 | embedding_layer.weight[new_token_ids] = new_token_embeddings
313 | print(
314 | "new:",
315 | model.t5_model.get_input_embeddings().weight.shape,
316 | model.t5_model.get_input_embeddings().weight.sum(),
317 | )
318 |
319 | # Set embedding layer as trainable
320 | model.t5_model.shared.weight.requires_grad = True
321 |
322 | if args.mode == "train":
323 | if args.use_lora:
324 | target_modules = get_attention_paths(model)
325 | modules_to_save = ["shared"]
326 |
327 | lora_config = LoraConfig(
328 | r=args.lora_r,
329 | lora_alpha=args.lora_alpha,
330 | lora_dropout=args.lora_dropout,
331 | target_modules=target_modules,
332 | modules_to_save=modules_to_save,
333 | )
334 |
335 | model = get_peft_model(model, lora_config)
336 | print_trainable_parameters(model)
337 |
338 | optimizer = torch.optim.AdamW(
339 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay
340 | )
341 | # Set up the learning rate scheduler
342 | total_steps = len(train_dataloader) * args.epochs
343 | scheduler = get_linear_schedule_with_warmup(
344 | optimizer, num_warmup_steps=0, num_training_steps=total_steps
345 | )
346 | report = Reporter()
347 | trainer.train(
348 | save_path,
349 | model,
350 | train_dataloader,
351 | val_dataloader,
352 | optimizer,
353 | scheduler,
354 | args.epochs,
355 | device=device,
356 | log_wandb=args.log_wandb,
357 | use_lora=args.use_lora,
358 | alpha=args.alpha,
359 | added_ent_type_tokens=added_ent_type_tokens,
360 | added_pk_tokens=added_pk_tokens,
361 | loss_mode=args.loss_mode,
362 | reporter=report,
363 | )
364 |
365 | if args.use_lora:
366 | print(
367 | f"t5_model.shared.original_module",
368 | model.t5_model.shared.original_module.weight.sum(),
369 | )
370 | print(
371 | f"t5_model.shared.modules_to_save",
372 | model.t5_model.shared.modules_to_save["default"].weight.sum(),
373 | )
374 |
375 | elif args.mode == "test":
376 | save_path = (
377 | f"saved/best_model/MuSEE/MuSEE_{data_abbrev}_m{args.pretrained_model_name}_"
378 | f"lr{args.lr}_wd{args.weight_decay}_{args.loss_mode}_lora_{args.use_lora}_init_{args.use_better_init}"
379 | f"best_model"
380 | )
381 | print("save_path:", save_path)
382 | if args.use_lora:
383 | model = PeftModel.from_pretrained(model, save_path)
384 | print(
385 | f"t5_model.shared.original_module",
386 | model.t5_model.shared.original_module.weight.sum(),
387 | )
388 | print(
389 | f"t5_model.shared.modules_to_save",
390 | model.t5_model.shared.modules_to_save["default"].weight.sum(),
391 | )
392 | model = model.merge_and_unload()
393 | else:
394 | model.load_state_dict(
395 | torch.load(f"{save_path}.pt", map_location=device)
396 | )
397 |
398 | print(
399 | "after load pretrained (get_input_embeddings):",
400 | model.t5_model.get_input_embeddings().weight.shape,
401 | model.t5_model.get_input_embeddings().weight.sum(),
402 | )
403 |
404 | print(
405 | "after load pretrained (shared):",
406 | model.t5_model.shared.weight.shape,
407 | model.t5_model.shared.weight.sum(),
408 | )
409 |
410 | model.eval()
411 | # generate json output
412 | # id2entity, id2property = dataset.id2entity(), dataset.id2property()
413 | results = Trainer_E_Pk_Pv.generate_full_json_output(
414 | model,
415 | test_dataloader,
416 | added_ent_type_tokens,
417 | added_pk_tokens,
418 | tokenizer,
419 | device,
420 | mode=args.mode,
421 | )
422 |
423 | final_json = generate_final_json(results, test_data_path)
424 | print("final_json:", json.dumps(final_json, indent=4))
425 |
426 | # Save to a JSON file
427 | prediction_path = (
428 | f"saved/best_model/MuSEE/saved_json/MuSEE_{data_abbrev}_m{args.pretrained_model_name}_"
429 | f"lr{args.lr}_wd{args.weight_decay}_{args.loss_mode}_lora_{args.use_lora}_init_{args.use_better_init}.json"
430 | )
431 | with open(prediction_path, "w", encoding="utf-8") as file:
432 | json.dump(final_json, file, ensure_ascii=False, indent=4)
433 |
434 | metrics = evaluate(test_data_path, prediction_path)
435 | print("metrics:", metrics)
436 |
437 | def run():
438 | args = parse_args()
439 | set_seed(args.seed)
440 | experiment(args)
441 |
442 |
443 | if __name__ == "__main__":
444 | run()
445 |
--------------------------------------------------------------------------------
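The use_better_init path above initializes each added special token's embedding as the mean of its subtoken embeddings plus Gaussian noise. A self-contained sketch of the same idea (sizes, ids, and noise_std_dev here are illustrative, not taken from the code):

import torch

torch.manual_seed(0)
vocab_size, d_model, noise_std_dev = 12, 4, 0.01
embedding = torch.nn.Embedding(vocab_size, d_model)
subtoken_ids = torch.tensor([2, 5, 7])  # pretend subtokens of "birth place"
new_token_id = vocab_size - 1           # pretend slot for the added token

with torch.no_grad():
    avg = embedding.weight[subtoken_ids].mean(dim=0)
    embedding.weight[new_token_id] = avg + torch.randn_like(avg) * noise_std_dev
--------------------------------------------------------------------------------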
/img/metric.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/img/metric.png
--------------------------------------------------------------------------------
/img/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/img/model.png
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | import json
2 | from copy import deepcopy
3 |
4 | import torch
5 | from scipy.sparse import csr_matrix
6 | from scipy.sparse.csgraph import min_weight_full_bipartite_matching
7 |
8 | # from utils import remove_duplicates_and_postprocess
9 |
10 |
11 | def compute_distances(gt, pred, measures, weights_pks=None):
12 | """
13 | Args:
14 | gt: list of ground-truth entities, where each entity has
15 | - a 't' field consisting of the one-hot type torch.tensor
16 | - a 'pk' field consisting of multihot torch.tensor indicating the present pks
17 | - a 'pv' field containing a list of the property values, in an order consistent with the 'pk' field.
18 | pred: list of predicted entities
19 | measures: the distance measure to apply; one of "ExactName", "ApproxName", or "MultiProp"
20 | weights_pks: torch.tensor defining a weight for each property key to compute weighted averages in the distance metrics.
21 | """
22 | assert measures in ["ExactName", "ApproxName", "MultiProp"]
23 |
24 | # if weights_pks is None:
25 | # weights_pks = torch.ones_like(gt[0]["pk"])
26 |
27 | # def entity_type_distance_ce(e1, e2):
28 | # t1, t2 = e1["t"].float(), e2["t"].float()
29 | # epsilon = 1e-5
30 | # t2 = torch.clamp(t2, min=epsilon) # Ensure values are not too close to zero
31 | # return -torch.sum(t1 * torch.log(t2))
32 | #
33 | # def entity_type_distance_acc(e1, e2):
34 | # acc = (torch.argmax(e1["t"]) == torch.argmax(e2["t"])).float()
35 | # return 1 - acc
36 | def entity_name_distance_approx(
37 | e1, e2, weights
38 | ): # same as the pv distance, but the name's weight is 1 and the rest are 0
39 | v1, v2 = e1["pv"], e2["pv"]
40 | # Use jaccard similarity to compute the distance of the name
41 | # Split each property value into tokens (words)
42 | tokens_v1 = [
43 | set(value.lower().split())
44 | for index, value in enumerate(v1)
45 | if weights[index] == 1
46 | ] # only for name
47 | tokens_v2 = [
48 | set(value.lower().split())
49 | for index, value in enumerate(v2)
50 | if weights[index] == 1
51 | ] # only for name
52 | jaccard_similarities = []
53 | _weights = weights.clone()
54 | for i, (t1, t2) in enumerate(zip(tokens_v1, tokens_v2)):
55 | # Compute the Jaccard similarity for the token sets
56 | intersection_size = len(t1.intersection(t2))
57 | union_size = len(t1.union(t2))
58 | if union_size == 0:
59 | jaccard_sim = 0.0
60 | _weights[i] = 0.0
61 | else:
62 | jaccard_sim = intersection_size / union_size
63 | jaccard_similarities.append(jaccard_sim)
64 | dist = (
65 | 1 - (torch.tensor(jaccard_similarities) * _weights).sum() / _weights.sum()
66 | )
67 | return dist
68 |
69 | def entity_name_distance_exact(
70 | e1, e2, weights
71 | ): # same as the pv distance, but the name's weight is 1 and the rest are 0
72 | v1, v2 = e1["pv"], e2["pv"]
73 | v1 = [
74 | value for index, value in enumerate(v1) if weights[index] == 1
75 | ] # only for name
76 | v2 = [
77 | value for index, value in enumerate(v2) if weights[index] == 1
78 | ] # only for name
79 | matching = torch.tensor([int(a.lower() == b.lower()) for a, b in zip(v1, v2)])
80 | if len(matching) == 0:
81 | return 1.0
82 | else:
83 | return 1 - (matching.float() * weights).sum() / weights.sum()
84 |
85 | # def property_key_distance_bce(e1, e2):
86 | # k1, k2 = e1["pk"].float(), e2["pk"].float()
87 | # epsilon = 1e-5
88 | # k2 = torch.clamp(
89 | # k2, min=epsilon, max=1 - epsilon
90 | # ) # Ensures values are between epsilon and 1-epsilon
91 | # return -torch.sum(k1 * torch.log(k2) + (1 - k1) * torch.log(1 - k2))
92 | #
93 | # def property_key_distance_acc(e1, e2):
94 | # k1, k2 = e1["pk"].float(), e2["pk"].float()
95 | # k2_preds = (k2 >= 0.5).float()
96 | # corrects = (k2_preds == k1).float().sum()
97 | # acc = corrects / len(k1)
98 | # return 1 - acc
99 | #
100 | # def property_value_prop_distance_acc(e1, e2, weights):
101 | # v1, v2 = e1["pv"], e2["pv"]
102 | #
103 | # matching = torch.tensor([int(a.lower() == b.lower()) for a, b in zip(v1, v2)])
104 | #
105 | # return 1 - (matching.float() * weights).sum() / weights.sum()
106 |
107 | def property_value_token_distance_acc(e1, e2, weights):
108 | v1, v2 = e1["pv"], e2["pv"]
109 |
110 | # Split each property value into tokens (words)
111 | tokens_v1 = [set(value.lower().split()) for value in v1]
112 | tokens_v2 = [set(value.lower().split()) for value in v2]
113 |
114 | jaccard_similarities = []
115 | _weights = weights.clone()
116 | for i, (t1, t2) in enumerate(zip(tokens_v1, tokens_v2)):
117 | # Compute the Jaccard similarity for the token sets
118 | intersection_size = len(t1.intersection(t2))
119 | union_size = len(t1.union(t2))
120 | if union_size == 0:
121 | jaccard_sim = 0.0
122 | _weights[i] = 0.0
123 | else:
124 | jaccard_sim = intersection_size / union_size
125 | jaccard_similarities.append(jaccard_sim)
126 | dist = (
127 | 1 - (torch.tensor(jaccard_similarities) * _weights).sum() / _weights.sum()
128 | )
129 | return dist
130 |
131 | distances = torch.zeros((len(gt), len(pred)))
132 |
133 | for i, g in enumerate(gt):
134 | for j, p in enumerate(pred):
135 | distance = 0
136 | if measures == "ExactName":
137 | distance = entity_name_distance_exact(g, p, weights_pks)
138 | elif measures == "ApproxName":
139 | distance = entity_name_distance_approx(g, p, weights_pks)
140 | elif measures == "MultiProp": # We can also use bce here
141 | distance = property_value_token_distance_acc(g, p, weights_pks)
142 | # if "E-CE" in measures:
143 | # distance += entity_type_distance_ce(g, p)
144 | # if "E-ACC" in measures:
145 | # distance += entity_type_distance_acc(g, p)
146 | # if "Pk-BCE" in measures:
147 | # distance += property_key_distance_bce(g, p)
148 | # if "Pk-ACC" in measures:
149 | # distance += property_key_distance_acc(g, p)
150 | # if "Pv-prop-ACC" in measures:
151 | # distance += property_value_prop_distance_acc(g, p, weights_pks)
152 | # if "Pv-token-ACC" in measures:
153 | # distance += property_value_token_distance_acc(g, p, weights_pks)
154 | distances[i, j] = distance
155 |
156 | return distances
157 |
158 |
159 | def bipartite_matching(distances):
160 | # Max normalizes by the larger of the ground-truth and prediction set sizes
161 | # Precision normalizes by the prediction set size
162 | # Recall normalizes by the ground-truth set size
163 | biadjacency_matrix = csr_matrix(distances.numpy())
164 | # Add a constant (e.g., 1) to every distance to ensure no zero values
165 | biadjacency_matrix = biadjacency_matrix + csr_matrix(
166 | torch.ones_like(distances).numpy()
167 | )
168 | # print("biadjacency_matrix:", biadjacency_matrix.todense())
169 | row_ind, col_ind = min_weight_full_bipartite_matching(
170 | biadjacency_matrix, maximize=False
171 | )
172 |
173 | # Subtract the added constant for each matched pair
174 | min_num_entity = min(biadjacency_matrix.shape[0], biadjacency_matrix.shape[1])
175 | max_num_entity = max(biadjacency_matrix.shape[0], biadjacency_matrix.shape[1])
176 | matched_distances = biadjacency_matrix[row_ind, col_ind].sum() - min_num_entity
177 | # # optimal_metric_loss = (
178 | # # (matched_distances + max_num_entity - min_num_entity) / max_num_entity
179 | # # if max_num_entity != 0
180 | # # else 0
181 | # # )
182 | # optimal_metric_loss = (matched_distances + denominator - min_num_entity) / denominator if denominator != 0 else 0
183 | # Obtain permutation
184 | permutation_ground_truth = torch.tensor(row_ind)[
185 | torch.argsort(torch.tensor(col_ind))
186 | ]
187 | permutation_prediction = torch.tensor(col_ind)
188 | return permutation_ground_truth, permutation_prediction
189 | # return optimal_metric_loss, permutation_ground_truth, permutation_prediction
190 |
191 |
192 | def compute_bipartite_matching_metrics(
193 | target: list,
194 | predicted_output: list,
195 | measures,
196 | normalization,
197 | establish_threshold=0.6,
198 | ):
199 | """Compute metrics based on bipartite matching"""
200 | assert normalization in ["Max", "Precision", "Recall"]
201 | target = deepcopy(target)
202 | predicted_output = deepcopy(predicted_output)
203 | target_set_size = len(target)
204 | predicted_output_set_size = len(predicted_output)
205 | if isinstance(target, dict):
206 | target = [entity for entity in target.values()]
207 | # print("Prediction", predicted_output)
208 | # print("Target", target)
209 | keys = set(key for entity in target + predicted_output for key in entity.keys())
210 | # print("keys:", keys)
211 | keys = list(keys)
212 |
213 | # target = remove_duplicates_and_postprocess(target)
214 | # predicted_output = remove_duplicates_and_postprocess(predicted_output)
215 |
216 | # Pad the predicted entities or ground-truth with dummy entities
217 | # to ensure that the number of entities is the same
218 | if target_set_size > predicted_output_set_size:
219 | predicted_output += [
220 | {} for _ in range(target_set_size - predicted_output_set_size)
221 | ]
222 | elif predicted_output_set_size > target_set_size:
223 | target += [{} for _ in range(predicted_output_set_size - target_set_size)]
224 |
225 | # print("Prediction", predicted_output)
226 | # print("Target", target)
227 |
228 | def get_key_index(key, keys):
229 | for i, k in enumerate(keys):
230 | if key == k:
231 | return i
232 | raise ValueError(f"The key {key} does not exist in the key list")
233 |
234 | # Create property key tensors
235 | def create_pk_tensor(entity, keys):
236 | tensor = [0] * len(keys)
237 | for key in entity.keys():
238 | tensor[get_key_index(key, keys)] = 1
239 | return torch.tensor(tensor)
240 |
241 | def create_pv_list(entity, keys):
242 | lst = [""] * len(keys)
243 | for key, value in entity.items():
244 | if not isinstance(value, str):
245 | if isinstance(value, list):
246 | value = " ".join(value)
247 | else:
248 | value = str(value)
249 | lst[get_key_index(key, keys)] = value
250 |
251 | return lst
252 |
253 | def jaccard_similarity(tokens_target, tokens_pred):
254 | # Compute the Jaccard similarity (intersection of the token set over the union)
255 | intersection_size = len(tokens_target.intersection(tokens_pred))
256 | union_size = len(tokens_target.union(tokens_pred))
257 | return intersection_size / union_size if union_size else 0.0  # guard against two empty token sets
258 |
259 | target_entities = [
260 | {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)} for e in target
261 | ]
262 | predicted_entities = [
263 | {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)}
264 | for e in predicted_output
265 | ]
266 |
267 | # Pad the predicted entities or ground-truth with dummy entities
268 | # to ensure that the number of entities is the same
269 | # if target_set_size > predicted_output_set_size:
270 | # predicted_entities += [
271 | # {"pk": torch.zeros_like(target_entities[0]["pk"]), "pv": [""]}
272 | # for _ in range(target_set_size - predicted_output_set_size)
273 | # ]
274 | # elif predicted_output_set_size > target_set_size:
275 | # target_entities += [
276 | # {"pk": torch.zeros_like(predicted_entities[0]["pk"]), "pv": [""]}
277 | # for _ in range(predicted_output_set_size - target_set_size)
278 | # ]
279 |
280 | # assume a weight of 1 for each property apart from name
281 | try:
282 | weights = torch.zeros_like(target_entities[0]["pk"])
283 | except IndexError:
284 | print(target_entities, target, predicted_output)
285 | raise  # weights would otherwise be unbound below
286 |
287 | # try:
288 | # if measures == "ExactName":
289 | # weights[get_key_index("name", keys)] = 1
290 | # elif measures == "ApproxName":
291 | # weights[get_key_index("name", keys)] = 1
292 | # elif measures == "MultiProp":
293 | # weights[get_key_index("name", keys)] = 2
294 | # for index, key in enumerate(keys):
295 | # if key != "name":
296 | # weights[index] = 1
297 | # except ValueError:
298 |
299 | if measures == "ExactName":
300 | weights[get_key_index("entity name", keys)] = 1
301 | elif measures == "ApproxName":
302 | weights[get_key_index("entity name", keys)] = 1
303 | elif measures == "MultiProp":
304 | weights[get_key_index("entity name", keys)] = 11
305 | for index, key in enumerate(keys):
306 | if key != "entity name":
307 | weights[index] = 1
308 | # pv_distances_aon = compute_distances(
309 | # target_entities, predicted_entities, ["Pv-prop-ACC"], weights
310 | # )
311 | # pv_distances_token = compute_distances(
312 | # target_entities, predicted_entities, ["Pv-token-ACC"], weights
313 | # )
314 | # pv_distances_aon_unweighted = compute_distances(
315 | # target_entities, predicted_entities, ["Pv-prop-ACC"], torch.ones_like(weights)
316 | # )
317 | # pv_distances_token_unweighted = compute_distances(
318 | # target_entities, predicted_entities, ["Pv-token-ACC"], torch.ones_like(weights)
319 | # )
320 | # pk_distances_acc = compute_distances(
321 | # target_entities, predicted_entities, ["Pk-ACC"], torch.ones_like(weights)
322 | # )
323 |
324 | # (
325 | # pv_distances_token_loss,
326 | # permutation_target,
327 | # permutation_prediction,
328 | # ) = bipartite_matching(pv_distances_token)
329 | # pv_distances_aon_loss, _, _ = bipartite_matching(pv_distances_aon)
330 | # pk_distances_acc_loss, _, _ = bipartite_matching(pk_distances_acc)
331 | # pv_distances_token_unweighted_loss, _, _ = bipartite_matching(
332 | # pv_distances_token_unweighted
333 | # )
334 | # pv_distances_aon_unweighted_loss, _, _ = bipartite_matching(
335 | # pv_distances_aon_unweighted
336 | # )
337 |
338 | if measures == "ExactName":
339 | entity_distance = compute_distances(
340 | target_entities, predicted_entities, "ExactName", weights
341 | )
342 | elif measures == "ApproxName":
343 | entity_distance = compute_distances(
344 | target_entities, predicted_entities, "ApproxName", weights
345 | )
346 | elif measures == "MultiProp":
347 | entity_distance = compute_distances(
348 | target_entities, predicted_entities, "MultiProp", weights
349 | )
350 |
351 | # if normalization == "Max":
352 | # (
353 | # permutation_target,
354 | # permutation_prediction,
355 | # ) = bipartite_matching(entity_distance)
356 | # elif normalization == "Precision":
357 | # (
358 | # permutation_target,
359 | # permutation_prediction,
360 | # ) = bipartite_matching(entity_distance)
361 | # elif normalization == "Recall":
362 | (
363 | permutation_target,
364 | permutation_prediction,
365 | ) = bipartite_matching(entity_distance)
366 |
367 | # Only establish matches that have a distance below threshold
368 | # the threshold and weight_pks is calibrated such that it does not suffice to have
369 | # a matched type property without a matched name, if the entity only contains name and type (which is often the case)
370 | established_entity_matches = []
371 | established_entity_matches_tensor = []
372 | # print("permutation_target", permutation_target)
373 | # print("permutation_prediction", permutation_prediction)
374 | # print("target", target)
375 | # print("predicted_output", predicted_output)
376 | for predicted_idx, target_idx in enumerate(permutation_target):
377 | # if entity_distance[target_idx, predicted_idx] <= establish_threshold:
378 | # established_entity_matches.append(
379 | # (target[target_idx], predicted_output[predicted_idx])
380 | # )
381 | # established_entity_matches.append(
382 | # (target[target_idx], predicted_output[predicted_idx])
383 | # )
384 | established_entity_matches.append(
385 | (target[target_idx], predicted_output[predicted_idx])
386 | )
387 | established_entity_matches_tensor.append(
388 | (target_entities[target_idx], predicted_entities[predicted_idx])
389 | )
390 |
391 | def property_value_token_distance_acc(e1, e2, weights):
392 | v1, v2 = e1["pv"], e2["pv"]
393 |
394 | # Split each property value into tokens (words)
395 | tokens_v1 = [set(value.lower().split()) for value in v1]
396 | tokens_v2 = [set(value.lower().split()) for value in v2]
397 |
398 | jaccard_similarities = []
399 | _weights = weights.clone()
400 | for i, (t1, t2) in enumerate(zip(tokens_v1, tokens_v2)):
401 | # Compute the Jaccard similarity for the token sets
402 | intersection_size = len(t1.intersection(t2))
403 | union_size = len(t1.union(t2))
404 | if union_size == 0:
405 | jaccard_sim = 0.0
406 | _weights[i] = 0.0
407 | else:
408 | jaccard_sim = intersection_size / union_size
409 | jaccard_similarities.append(jaccard_sim)
410 | similarity = (
411 | torch.tensor(jaccard_similarities) * _weights
412 | ).sum() / _weights.sum()
413 | return similarity.item()
414 |
415 | # target_entities = [
416 | # {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)} for e in target
417 | # ]
418 | # predicted_entities = [
419 | # {"pk": create_pk_tensor(e, keys), "pv": create_pv_list(e, keys)}
420 | # for e in predicted_output
421 | # ]
422 |
423 | # assume a weight of 1 for each property
424 | weights = torch.ones_like(target_entities[0]["pk"])
425 |
426 | all_similarities = []
427 | for t, p in established_entity_matches_tensor:
428 | similarity = property_value_token_distance_acc(t, p, weights)
429 | all_similarities.append(similarity)
430 |
431 | if normalization == "Max":
432 | normalized_similarity = sum(all_similarities) / max(
433 | target_set_size, predicted_output_set_size
434 | )
435 | elif normalization == "Precision":
436 | normalized_similarity = sum(all_similarities) / predicted_output_set_size
437 | elif normalization == "Recall":
438 | normalized_similarity = sum(all_similarities) / target_set_size
439 |
440 | per_property_acc_token = {key: 0.0 for key in keys}
441 | per_property_acc_aon = {key: 0.0 for key in keys}
442 | target_key_occurance = {key: 0.0 for key in keys}
443 | pred_key_occurance = {key: 0.0 for key in keys}
444 | key_matches = {key: 0.0 for key in keys}
445 | for e_target, e_pred in established_entity_matches:
446 | for key in keys:
447 | target_key_occurance[key] += key in e_target.keys()
448 | pred_key_occurance[key] += key in e_pred.keys()
449 | if key in e_target.keys() and key in e_pred.keys():
450 | key_matches[key] += 1.0
451 | tokens_target = set(e_target[key].lower().split())
452 | tokens_pred = set(e_pred[key].lower().split())
453 |
454 | jaccard_sim = jaccard_similarity(tokens_target, tokens_pred)
455 | per_property_acc_token[key] += jaccard_sim
456 |
457 | per_property_acc_aon[key] += (
458 | e_target[key].lower() == e_pred[key].lower()
459 | )
460 |
461 | # # calculate per property similarity
462 | # prop_similarities = {}
463 | #
464 | # for pk in keys:
465 | # prop_similarities[pk] = {}
466 | # all_sim = []
467 | # weights = torch.zeros_like(target_entities[0]["pk"])
468 | # weights[get_key_index(pk, keys)] = 1
469 | # for (t, p) in established_entity_matches_tensor:
470 | # similarity = property_value_token_distance_acc(t, p, weights)
471 | # all_sim.append(similarity)
472 | # prop_similarities[pk]["Max"] = sum(all_sim) / max(target_set_size, predicted_output_set_size)
473 | # prop_similarities[pk]["Precision"] = sum(all_sim) / predicted_output_set_size
474 | # prop_similarities[pk]["Recall"] = sum(all_sim) / target_set_size
475 | # if str(prop_similarities[pk]["Max"]) == "nan":
476 | # prop_similarities[pk]["Max"] = 0
477 | # if str(prop_similarities[pk]["Precision"]) == "nan":
478 | # prop_similarities[pk]["Precision"] = 0
479 | # if str(prop_similarities[pk]["Recall"]) == "nan":
480 | # prop_similarities[pk]["Recall"] = 0
481 |
482 | counts_nested = {
483 | "per_property_acc_token": per_property_acc_token,
484 | "per_property_acc_aon": per_property_acc_aon,
485 | "target_key_occurance": target_key_occurance,
486 | "pred_key_occurance": pred_key_occurance,
487 | "key_matches": key_matches,
488 | }
489 |
490 | counts = {
491 | "established_entity_matches": len(established_entity_matches),
492 | "predicted_output_entities_no_dup": len(predicted_output),
493 | "target_entities_no_dup": len(target),
494 | }
495 |
496 | # bipartite_matching_metrics = {
497 | # "normalized_similarity": normalized_similarity,
498 | # # "pv_distances_token_loss": pv_distances_token_loss,
499 | # # "pv_distances_aon_loss": pv_distances_aon_loss,
500 | # # "pk_distances_acc_loss": pk_distances_acc_loss,
501 | # # "pv_distances_token_unweighted_loss": pv_distances_token_unweighted_loss,
502 | # # "pv_distances_aon_unweighted_loss": pv_distances_aon_unweighted_loss,
503 | # }
504 | final_metrics = {
505 | "normalized_similarity": normalized_similarity,
506 | }
507 |
508 | return final_metrics, counts, counts_nested
509 |
510 |
511 | if __name__ == "__main__":
512 | # Sample usage
513 |
514 | dummy_entity = {
515 | "t": torch.tensor([0.0, 0.0]),
516 | "pk": torch.tensor([0.0, 0.0, 0.0]),
517 | "pv": ["dummy", "dummy", "dummy"],
518 | }
519 |
520 | # In this demo, we assume N=3 (max entity), M=2 (num entity type), K=3 (num property keys)
521 | # t (ground-truth) is a one-hot vector. pk (ground-truth) is a multi-hot vector.
522 | # t (prediction) is post-softmax. pk (prediction) is post-sigmoid.
523 | gt = [
524 | {
525 | "t": torch.tensor([1.0, 0.0]),
526 | "pk": torch.tensor([1.0, 1.0, 1.0]),
527 | "pv": ["XX apple", "round", "big"],
528 | },
529 | {
530 | "t": torch.tensor([0.0, 1.0]),
531 | "pk": torch.tensor([1.0, 0.0, 1.0]),
532 | "pv": ["YY banana", "long", "big"],
533 | },
534 | {
535 | "t": torch.tensor([1.0, 0.0]),
536 | "pk": torch.tensor([0.0, 1.0, 1.0]),
537 | "pv": ["ZZ grape", "round", "small"],
538 | },
539 | ]
540 | pred1 = [
541 | {
542 | "t": torch.tensor([1.0, 0.0]),
543 | "pk": torch.tensor([0.0, 1.0, 1.0]),
544 | "pv": ["ZZ grape", "round", "small"],
545 | },
546 | {
547 | "t": torch.tensor([1.0, 0.0]),
548 | "pk": torch.tensor([1.0, 1.0, 1.0]),
549 | "pv": ["XX apple", "round", "big"],
550 | },
551 | {
552 | "t": torch.tensor([0.0, 1.0]),
553 | "pk": torch.tensor([1.0, 0.0, 1.0]),
554 | "pv": ["YY banana", "long", "big"],
555 | },
556 | ]
557 |
558 | pred2 = [
559 | {
560 | "t": torch.tensor([0.999, 0.001]),
561 | "pk": torch.tensor([1.0, 1.0, 1.0]),
562 | "pv": ["YY apple", "round", "big"],
563 | },
564 | {
565 | "t": torch.tensor([0.001, 0.999]),
566 | "pk": torch.tensor([1.0, 0.1, 1.0]),
567 | "pv": ["XX banana", "long", "small"],
568 | },
569 | {
570 | "t": torch.tensor([0.8, 0.2]),
571 | "pk": torch.tensor([0.4, 0.9, 0.4]),
572 | "pv": ["ZZ grape", "round", "small"],
573 | },
574 | ]
575 |
576 | pred3 = [
577 | {
578 | "t": torch.tensor([0.7, 0.3]),
579 | "pk": torch.tensor([0.2, 0.8, 0.8]),
580 | "pv": ["XX peach", "round", "very small"],
581 | },
582 | dummy_entity,
583 | dummy_entity,
584 | ]
585 |
586 | pred_list = [pred1, pred2, pred3]
587 | for i in range(len(pred_list)):
588 | print(f"Compare GT with Pred{i + 1}:")
589 | # Legacy measures ("E-CE", "Pk-BCE", "E-ACC", "Pk-ACC", "Pv-prop-ACC",
590 | # "Pv-token-ACC") survive only as comments inside compute_distances;
591 | # the current API takes a single measure name plus per-key weights.
592 | distances = compute_distances(
593 | gt, pred_list[i], measures="MultiProp", weights_pks=torch.ones(3)
594 | )
595 |
596 | (
597 | permutation_ground_truth,
598 | permutation_prediction,
599 | ) = bipartite_matching(distances)
600 | print("permutation_ground_truth:", permutation_ground_truth)
601 | print("permutation_prediction:", permutation_prediction)
602 | print("-----")
603 |
--------------------------------------------------------------------------------
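A standalone illustration of the matching step in bipartite_matching above. The +1 offset matters because csr_matrix drops explicit zeros, so a perfect match (distance 0) would otherwise vanish from the bipartite graph:

import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import min_weight_full_bipartite_matching

distances = torch.tensor([[0.0, 0.9], [0.8, 0.1]])  # rows: gt, cols: pred
biadj = csr_matrix((distances + 1.0).numpy())       # offset keeps zero distances
row_ind, col_ind = min_weight_full_bipartite_matching(biadj, maximize=False)
print(row_ind, col_ind)  # [0 1] [0 1]: gt 0 <-> pred 0, gt 1 <-> pred 1
--------------------------------------------------------------------------------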
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/model/__init__.py
--------------------------------------------------------------------------------
/model/t5_with_t5decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import T5Config, T5ForConditionalGeneration
4 |
5 |
6 | class T5_with_T5Decoder(nn.Module):
7 | def __init__(self, pretrained_t5_name, tokenizer, pre_train=True):
8 | super().__init__()
9 |
10 | if pre_train:
11 | print("Using Pre-trained T5 model")
12 | self.t5_model = T5ForConditionalGeneration.from_pretrained(
13 | pretrained_t5_name
14 | )
15 | else:
16 | print("Using Randomly Initialized T5 model")
17 | self.t5_model = T5ForConditionalGeneration(T5Config())
18 | print("Window Size:", self.t5_model.config.d_model)
19 | self.tokenizer = tokenizer
20 | self.d_model = self.t5_model.config.d_model
21 |
22 | def forward(self, input_ids=None, attention_mask=None):
23 |
24 | original_text_embeddings = self.t5_model.shared(input_ids)
25 | # Encode the input text
26 | encoder_outputs = self.t5_model.encoder(
27 | input_ids=input_ids,
28 | attention_mask=attention_mask,
29 | )
30 |
31 | return original_text_embeddings, encoder_outputs
32 |
33 | def decode_at(self, encoder_outputs, decoder_input_ids, target_sequence_length):
34 | all_logits = []
35 | past_key_values = None
36 |
37 | for i in range(target_sequence_length):
38 | outputs = self.t5_model(
39 | input_ids=None,
40 | attention_mask=None,
41 | decoder_input_ids=decoder_input_ids,
42 | encoder_outputs=encoder_outputs,
43 | past_key_values=past_key_values,
44 | use_cache=True,
45 | return_dict=True,
46 | )
47 |
48 | next_token_logits = outputs.logits[:, -1]
49 | next_tokens = next_token_logits.argmax(-1, keepdim=True)
50 | all_logits.append(next_token_logits.unsqueeze(1))
51 |
52 | # With use_cache=True, only the newly generated token is fed back in;
53 | # the cached key/values already cover the earlier positions (feeding the
54 | # full prefix alongside the cache would duplicate context).
55 | decoder_input_ids = next_tokens
56 | past_key_values = outputs.past_key_values
57 |
58 | return torch.cat(all_logits, dim=1)
59 |
--------------------------------------------------------------------------------
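A hypothetical usage sketch for T5_with_T5Decoder (the checkpoint name, input text, and sequence length are examples, not fixed by the class above):

import torch
from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
model = T5_with_T5Decoder("t5-small", tok, pre_train=True)
inputs = tok("Barack Obama was born in Hawaii.", return_tensors="pt")
_, enc_out = model(inputs.input_ids, inputs.attention_mask)

# T5 decoding starts from the configured decoder start token (the pad token)
start = torch.full((1, 1), model.t5_model.config.decoder_start_token_id)
logits = model.decode_at(enc_out, start, target_sequence_length=5)
print(logits.shape)  # (1, 5, vocab_size)
--------------------------------------------------------------------------------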
/model/t5_with_t5decoder_emb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import T5Config, T5ForConditionalGeneration
4 |
5 |
6 | class T5_with_T5Decoder(nn.Module):
7 | def __init__(self, pretrained_t5_name, tokenizer, pre_train=True):
8 | super().__init__()
9 |
10 | if pre_train:
11 | print("Using Pre-trained T5 model")
12 | self.t5_model = T5ForConditionalGeneration.from_pretrained(
13 | pretrained_t5_name
14 | )
15 | else:
16 | print("Using Randomly Initialized T5 model")
17 | self.t5_model = T5ForConditionalGeneration(T5Config())
18 | print("Window Size:", self.t5_model.config.d_model)
19 | self.tokenizer = tokenizer
20 | self.d_model = self.t5_model.config.d_model
21 |
22 | def forward(
23 | self,
24 | input_ids=None,
25 | attention_mask=None,
26 | decoder_input_ids=None,
27 | decoder_attention_mask=None,
28 | output_hidden_states=False,
29 | return_dict=True,
30 | ):
31 |
32 | original_text_embeddings = self.t5_model.shared(input_ids)
33 | # Encode the input text
34 | encoder_outputs = self.t5_model.encoder(
35 | input_ids=input_ids,
36 | attention_mask=attention_mask,
37 | # return_dict=True,
38 | # output_hidden_states=output_hidden_states
39 | )
40 |
41 | return original_text_embeddings, encoder_outputs
42 |
43 | def decode_at_emb(self, encoder_outputs, target_sequence_length=20):
44 | device = encoder_outputs.last_hidden_state.device
45 | num_position = encoder_outputs.last_hidden_state.size(0)
46 |
47 | # Initialize start tokens
48 | start_token_id = (
49 | self.tokenizer.bos_token_id
50 | if self.tokenizer.bos_token_id is not None
51 | else self.tokenizer.pad_token_id
52 | )
53 | start_tokens = torch.full(
54 | (num_position, 1),
55 | start_token_id,
56 | dtype=torch.long,
57 | device=device,
58 | )
59 |
60 | # Get the embeddings of the start tokens
61 | decoder_inputs_embeds = self.t5_model.get_input_embeddings()(start_tokens)
62 |
63 | # Initialize the decoder's attention mask with a single 1 for the start token
64 | decoder_attention_mask = torch.ones(
65 | (num_position, 1), dtype=torch.long, device=device
66 | )
67 |
68 | past_key_values = None
69 | return_embeds = None
70 |
71 | for _ in range(target_sequence_length):
72 | outputs = self.t5_model(
73 | encoder_outputs=encoder_outputs,
74 | decoder_inputs_embeds=decoder_inputs_embeds,
75 | decoder_attention_mask=decoder_attention_mask,
76 | past_key_values=past_key_values,
77 | use_cache=True,
78 | output_hidden_states=True,
79 | )
80 |
81 | # Extract the last hidden state (embedding) of the last token
82 | last_embedding = outputs.decoder_hidden_states[-1][
83 | :, -1:, :
84 | ] # (b, 1, d_model)
85 | return_embeds = (
86 | last_embedding
87 | if return_embeds is None
88 | else torch.cat([return_embeds, last_embedding], dim=1)
89 | )
90 |
91 | # with past_key_values, only the last decoder_inputs_embeds needs to be fed
92 | decoder_inputs_embeds = last_embedding
93 |
94 | # Update the decoder's attention mask
95 | decoder_attention_mask = torch.cat(
96 | [
97 | decoder_attention_mask,
98 | torch.ones((num_position, 1), dtype=torch.long, device=device),
99 | ],
100 | dim=1,
101 | )
102 |
103 | # Store past key values for the next iteration
104 | past_key_values = outputs.past_key_values
105 |
106 | return return_embeds
107 |
--------------------------------------------------------------------------------
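Unlike the variant above, this decoder feeds its own last hidden state back in as the next input embedding and returns embeddings instead of logits. A hypothetical shape check (checkpoint name is an example):

from transformers import T5Tokenizer

tok = T5Tokenizer.from_pretrained("t5-small")
model = T5_with_T5Decoder("t5-small", tok)
inputs = tok("some input text", return_tensors="pt")
_, enc_out = model(inputs.input_ids, inputs.attention_mask)
embeds = model.decode_at_emb(enc_out, target_sequence_length=3)
print(embeds.shape)  # (batch, 3, d_model)
--------------------------------------------------------------------------------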
/model/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/Structured-Entity-Extraction/8e3d0fbb009874b5ba35a450de1ff37995b8102e/model/utils/__init__.py
--------------------------------------------------------------------------------
/model/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | import torch as th
11 | import torch.distributed as dist
12 |
13 | # Change this to reflect your cluster layout.
14 |
15 |
16 | def setup_dist():
17 | """
18 | Setup a distributed process group.
19 | """
20 | if dist.is_initialized():
21 | return
22 |
23 | backend = "gloo" if not th.cuda.is_available() else "nccl"
24 |
25 | if backend == "gloo":
26 | hostname = "localhost"
27 | else:
28 | hostname = socket.gethostbyname(socket.getfqdn())
29 |
30 | if os.environ.get("LOCAL_RANK") is None:
31 | os.environ["MASTER_ADDR"] = hostname
32 | os.environ["RANK"] = str(0)
33 | os.environ["WORLD_SIZE"] = str(1)
34 | port = _find_free_port()
35 | os.environ["MASTER_PORT"] = str(port)
36 | os.environ["LOCAL_RANK"] = str(0)
37 |
38 | dist.init_process_group(backend=backend, init_method="env://")
39 |
40 | if th.cuda.is_available(): # This clears remaining caches in GPU 0
41 | th.cuda.set_device(dev())
42 | th.cuda.empty_cache()
43 |
44 |
45 | def dev():
46 | """
47 | Get the device to use for torch.distributed.
48 | """
49 | if th.cuda.is_available():
50 | return th.device(f"cuda:{os.environ['LOCAL_RANK']}")
51 | return th.device("cpu")
52 |
53 |
54 | def load_state_dict(path, **kwargs):
55 | """
56 | Load a PyTorch file.
57 | """
58 | # if int(os.environ['LOCAL_RANK']) == 0:
59 | with bf.BlobFile(path, "rb") as f:
60 | data = f.read()
61 | return th.load(io.BytesIO(data), **kwargs)
62 |
63 |
64 | def sync_params(params):
65 | """
66 | Synchronize a sequence of Tensors across ranks from rank 0.
67 | """
68 | for p in params:
69 | with th.no_grad():
70 | dist.broadcast(p, 0)
71 |
72 |
73 | def _find_free_port():
74 | # create the socket outside try so `finally` never sees it unbound
75 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
76 | try:
77 | # SO_REUSEADDR must be set before bind to take effect
78 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
79 | s.bind(("127.0.0.1", 0)) # Bind to the local interface only
80 | return s.getsockname()[1]
81 | finally:
82 | s.close()
83 |
--------------------------------------------------------------------------------
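A minimal single-process smoke test for these helpers, assuming the module is importable as model.utils.dist_util; without CUDA it initializes a one-rank gloo group on localhost:

from model.utils import dist_util

dist_util.setup_dist()   # fills in MASTER_ADDR/RANK/WORLD_SIZE when unset
print(dist_util.dev())   # cuda:LOCAL_RANK if available, else cpu
--------------------------------------------------------------------------------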
/model/utils/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import torch.nn as nn
6 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
7 |
8 |
9 | def convert_module_to_f16(l):
10 | """
11 | Convert primitive modules to float16.
12 | """
13 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
14 | l.weight.data = l.weight.data.half()
15 | l.bias.data = l.bias.data.half()
16 |
17 |
18 | def convert_module_to_f32(l):
19 | """
20 | Convert primitive modules to float32, undoing convert_module_to_f16().
21 | """
22 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
23 | l.weight.data = l.weight.data.float()
24 | l.bias.data = l.bias.data.float()
25 |
26 |
27 | def make_master_params(model_params):
28 | """
29 | Copy model parameters into a (differently-shaped) list of full-precision
30 | parameters.
31 | """
32 | master_params = _flatten_dense_tensors(
33 | [param.detach().float() for param in model_params]
34 | )
35 | master_params = nn.Parameter(master_params)
36 | master_params.requires_grad = True
37 | return [master_params]
38 |
39 |
40 | def model_grads_to_master_grads(model_params, master_params):
41 | """
42 | Copy the gradients from the model parameters into the master parameters
43 | from make_master_params().
44 | """
45 | master_params[0].grad = _flatten_dense_tensors(
46 | [param.grad.data.detach().float() for param in model_params]
47 | )
48 |
49 |
50 | def master_params_to_model_params(model_params, master_params):
51 | """
52 | Copy the master parameter data back into the model parameters.
53 | """
54 | # Without copying to a list, if a generator is passed, this will
55 | # silently not copy any parameters.
56 | model_params = list(model_params)
57 |
58 | for param, master_param in zip(
59 | model_params, unflatten_master_params(model_params, master_params)
60 | ):
61 | param.detach().copy_(master_param)
62 |
63 |
64 | def unflatten_master_params(model_params, master_params):
65 | """
66 | Unflatten the master parameters to look like model_params.
67 | """
68 | return _unflatten_dense_tensors(master_params[0].detach(), model_params)
69 |
70 |
71 | def zero_grad(model_params):
72 | for param in model_params:
73 | # Taken from https://pytorch.org/docs/stable/_modules/torch/
74 | # optim/optimizer.html#Optimizer.add_param_group
75 | if param.grad is not None:
76 | param.grad.detach_()
77 | param.grad.zero_()
78 |
--------------------------------------------------------------------------------
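A round-trip sketch for the flatten/unflatten helpers above on a toy fp32 model, assuming the module is importable as model.utils.fp16_util; a plain SGD step is applied to the flat master copy and written back:

import torch
import torch.nn as nn
from model.utils import fp16_util

model = nn.Linear(4, 2)
params = list(model.parameters())
master = fp16_util.make_master_params(params)  # one flat fp32 Parameter

loss = model(torch.randn(3, 4)).sum()
loss.backward()
fp16_util.model_grads_to_master_grads(params, master)
master[0].data -= 0.1 * master[0].grad         # SGD step on the flat copy
fp16_util.master_params_to_model_params(params, master)
--------------------------------------------------------------------------------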
/model/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import datetime
7 | import json
8 | import os
9 | import os.path as osp
10 | # import shutil
11 | import sys
12 | import tempfile
13 | import time
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | import wandb
19 |
20 | DEBUG = 10
21 | INFO = 20
22 | WARN = 30
23 | ERROR = 40
24 |
25 | DISABLED = 50
26 |
27 |
28 | class KVWriter(object):
29 | def writekvs(self, kvs):
30 | raise NotImplementedError
31 |
32 |
33 | class SeqWriter(object):
34 | def writeseq(self, seq):
35 | raise NotImplementedError
36 |
37 |
38 | class HumanOutputFormat(KVWriter, SeqWriter):
39 | def __init__(self, filename_or_file):
40 | if isinstance(filename_or_file, str):
41 | self.file = open(filename_or_file, "wt")
42 | self.own_file = True
43 | else:
44 | assert hasattr(filename_or_file, "read"), (
45 | "expected file or str, got %s" % filename_or_file
46 | )
47 | self.file = filename_or_file
48 | self.own_file = False
49 |
50 | def writekvs(self, kvs):
51 | # Create strings for printing
52 | key2str = {}
53 | for key, val in sorted(kvs.items()):
54 | if hasattr(val, "__float__"):
55 | valstr = "%-8.3g" % val
56 | else:
57 | valstr = str(val)
58 | key2str[self._truncate(key)] = self._truncate(valstr)
59 |
60 | # Find max widths
61 | if len(key2str) == 0:
62 | print("WARNING: tried to write empty key-value dict")
63 | return
64 | else:
65 | keywidth = max(map(len, key2str.keys()))
66 | valwidth = max(map(len, key2str.values()))
67 |
68 | # Write out the data
69 | dashes = "-" * (keywidth + valwidth + 7)
70 | lines = [dashes]
71 | for key, val in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
72 | lines.append(
73 | "| %s%s | %s%s |"
74 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
75 | )
76 | lines.append(dashes)
77 | self.file.write("\n".join(lines) + "\n")
78 |
79 | # Flush the output to the file
80 | self.file.flush()
81 |
82 | def _truncate(self, s):
83 | maxlen = 30
84 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
85 |
86 | def writeseq(self, seq):
87 | seq = list(seq)
88 | for i, elem in enumerate(seq):
89 | self.file.write(elem)
90 | if i < len(seq) - 1: # add space unless this is the last one
91 | self.file.write(" ")
92 | self.file.write("\n")
93 | self.file.flush()
94 |
95 | def close(self):
96 | if self.own_file:
97 | self.file.close()
98 |
99 |
100 | class JSONOutputFormat(KVWriter):
101 | def __init__(self, filename):
102 | self.file = open(filename, "wt")
103 |
104 | def writekvs(self, kvs):
105 | for k, v in sorted(kvs.items()):
106 | if hasattr(v, "dtype"):
107 | kvs[k] = float(v)
108 | self.file.write(json.dumps(kvs) + "\n")
109 | self.file.flush()
110 |
111 | def close(self):
112 | self.file.close()
113 |
114 |
115 | class CSVOutputFormat(KVWriter):
116 | def __init__(self, filename):
117 | self.file = open(filename, "w+t")
118 | self.keys = []
119 | self.sep = ","
120 |
121 | def writekvs(self, kvs):
122 | # Add our current row to the history
123 | extra_keys = list(kvs.keys() - self.keys)
124 | extra_keys.sort()
125 | if extra_keys:
126 | self.keys.extend(extra_keys)
127 | self.file.seek(0)
128 | lines = self.file.readlines()
129 | self.file.seek(0)
130 | for i, k in enumerate(self.keys):
131 | if i > 0:
132 | self.file.write(",")
133 | self.file.write(k)
134 | self.file.write("\n")
135 | for line in lines[1:]:
136 | self.file.write(line[:-1])
137 | self.file.write(self.sep * len(extra_keys))
138 | self.file.write("\n")
139 | for i, k in enumerate(self.keys):
140 | if i > 0:
141 | self.file.write(",")
142 | v = kvs.get(k)
143 | if v is not None:
144 | self.file.write(str(v))
145 | self.file.write("\n")
146 | self.file.flush()
147 |
148 | def close(self):
149 | self.file.close()
150 |
151 |
152 | class TensorBoardOutputFormat(KVWriter):
153 | """
154 | Dumps key/value pairs into TensorBoard's numeric format.
155 | """
156 |
157 | def __init__(self, dir):
158 | os.makedirs(dir, exist_ok=True)
159 | self.dir = dir
160 | self.step = 1
161 | prefix = "events"
162 | path = osp.join(osp.abspath(dir), prefix)
163 | import tensorflow as tf
164 | from tensorflow.core.util import event_pb2
165 | from tensorflow.python import pywrap_tensorflow
166 | from tensorflow.python.util import compat
167 |
168 | self.tf = tf
169 | self.event_pb2 = event_pb2
170 | self.pywrap_tensorflow = pywrap_tensorflow
171 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
172 |
173 | def writekvs(self, kvs):
174 | def summary_val(k, v):
175 | kwargs = {"tag": k, "simple_value": float(v)}
176 | return self.tf.Summary.Value(**kwargs)
177 |
178 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
179 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
180 | event.step = (
181 | self.step
182 | ) # is there any reason why you'd want to specify the step?
183 | self.writer.WriteEvent(event)
184 | self.writer.Flush()
185 | self.step += 1
186 |
187 | def close(self):
188 | if self.writer:
189 | self.writer.Close()
190 | self.writer = None
191 |
192 |
193 | def make_output_format(format, ev_dir, log_suffix=""):
194 | os.makedirs(ev_dir, exist_ok=True)
195 | if format == "stdout":
196 | return HumanOutputFormat(sys.stdout)
197 | elif format == "log":
198 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
199 | elif format == "json":
200 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
201 | elif format == "csv":
202 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
203 | elif format == "tensorboard":
204 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
205 | else:
206 | raise ValueError("Unknown format specified: %s" % (format,))
207 |
208 |
209 | # ================================================================
210 | # API
211 | # ================================================================
212 |
213 |
214 | def logkv(key, val):
215 | """
216 | Log a value of some diagnostic.
217 | Call this once for each diagnostic quantity, each iteration.
218 | If called many times, only the last value will be used.
219 | """
220 | get_current().logkv(key, val)
221 |
222 |
223 | def logkv_mean(key, val):
224 | """
225 | The same as logkv(), but if called many times, the values will be averaged.
226 | """
227 | get_current().logkv_mean(key, val)
228 |
229 |
230 | def logkvs(d):
231 | """
232 | Log a dictionary of key-value pairs
233 | """
234 | for k, v in d.items():
235 | logkv(k, v)
236 |
237 |
238 | def dumpkvs():
239 | """
240 | Write all of the diagnostics from the current iteration
241 | """
242 | return get_current().dumpkvs()
243 |
244 |
245 | def getkvs():
246 | return get_current().name2val
247 |
248 |
249 | def log(*args, level=INFO):
250 | """
251 | Write the sequence of args, separated by spaces,
252 | to the console and output files (if you've configured an output file).
253 | """
254 | get_current().log(*args, level=level)
255 |
256 |
257 | def debug(*args):
258 | log(*args, level=DEBUG)
259 |
260 |
261 | def info(*args):
262 | log(*args, level=INFO)
263 |
264 |
265 | def warn(*args):
266 | log(*args, level=WARN)
267 |
268 |
269 | def error(*args):
270 | log(*args, level=ERROR)
271 |
272 |
273 | def set_level(level):
274 | """
275 | Set logging threshold on current logger.
276 | """
277 | get_current().set_level(level)
278 |
279 |
280 | def set_comm(comm):
281 | get_current().set_comm(comm)
282 |
283 |
284 | def get_dir():
285 | """
286 | Get the directory that log files are being written to.
287 | Will be None if there is no output directory (i.e., if you didn't call start).
288 | """
289 | return get_current().get_dir()
290 |
291 |
292 | record_tabular = logkv
293 | dump_tabular = dumpkvs
294 |
295 |
296 | @contextmanager
297 | def profile_kv(scopename):
298 | logkey = "wait_" + scopename
299 | tstart = time.time()
300 | try:
301 | yield
302 | finally:
303 | get_current().name2val[logkey] += time.time() - tstart
304 |
305 |
306 | def profile(n):
307 | """
308 | Usage:
309 | @profile("my_func")
310 | def my_func(): code
311 | """
312 |
313 | def decorator_with_name(func):
314 | def func_wrapper(*args, **kwargs):
315 | with profile_kv(n):
316 | return func(*args, **kwargs)
317 |
318 | return func_wrapper
319 |
320 | return decorator_with_name
321 |
322 |
323 | # ================================================================
324 | # Backend
325 | # ================================================================
326 |
327 |
328 | def get_current():
329 | if Logger.CURRENT is None:
330 | _configure_default_logger()
331 |
332 | return Logger.CURRENT
333 |
334 |
335 | class Logger(object):
336 | DEFAULT = None # A logger with no output files. (See right below class definition)
337 | # So that you can still log to the terminal without setting up any output files
338 | CURRENT = None # Current logger being used by the free functions above
339 |
340 | def __init__(self, dir, output_formats, comm=None):
341 | self.name2val = defaultdict(float) # values this iteration
342 | self.name2cnt = defaultdict(int)
343 | self.level = INFO
344 | self.dir = dir
345 | self.output_formats = output_formats
346 | self.comm = comm
347 |
348 | # Logging API, forwarded
349 | # ----------------------------------------
350 | def logkv(self, key, val):
351 | self.name2val[key] = val
352 |
353 | def logkv_mean(self, key, val):
354 | oldval, cnt = self.name2val[key], self.name2cnt[key]
355 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
356 | self.name2cnt[key] = cnt + 1
357 |
358 | def dumpkvs(self, prefix=None):
359 | if self.comm is None:
360 | d = self.name2val
361 | else:
362 | d = mpi_weighted_mean(
363 | self.comm,
364 | {
365 | name: (val, self.name2cnt.get(name, 1))
366 | for (name, val) in self.name2val.items()
367 | },
368 | )
369 | if self.comm.rank != 0:
370 | d["dummy"] = 1 # so we don't get a warning about empty dict
371 | # LISA
372 | out = d.copy() # Return the dict for unit testing purposes
373 | if int(os.environ.get("LOCAL_RANK", "0")) == 0:  # default to rank 0 when launched without torchrun
374 | wandb.log({**d})
375 | for fmt in self.output_formats:
376 | if isinstance(fmt, KVWriter):
377 | fmt.writekvs(d)
378 | self.name2val.clear()
379 | self.name2cnt.clear()
380 | return out
381 |
382 | def log(self, *args, level=INFO):
383 | if self.level <= level:
384 | self._do_log(args)
385 |
386 | # Configuration
387 | # ----------------------------------------
388 | def set_level(self, level):
389 | self.level = level
390 |
391 | def set_comm(self, comm):
392 | self.comm = comm
393 |
394 | def get_dir(self):
395 | return self.dir
396 |
397 | def close(self):
398 | for fmt in self.output_formats:
399 | fmt.close()
400 |
401 | # Misc
402 | # ----------------------------------------
403 | def _do_log(self, args):
404 | for fmt in self.output_formats:
405 | if isinstance(fmt, SeqWriter):
406 | fmt.writeseq(map(str, args))
407 |
408 |
409 | def get_rank_without_mpi_import():
410 | # check environment variables here instead of importing mpi4py
411 | # to avoid calling MPI_Init() when this module is imported
412 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
413 | if varname in os.environ:
414 | return int(os.environ[varname])
415 | return 0
416 |
417 |
418 | def mpi_weighted_mean(comm, local_name2valcount):
419 | """
420 | Copied from: https://github.com/openai/baselines/blob/
421 | ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
422 | Perform a weighted average over dicts that are each on a different node
423 | Input: local_name2valcount: dict mapping key -> (value, count)
424 | Returns: key -> mean
425 | """
426 | all_name2valcount = comm.gather(local_name2valcount)
427 | if comm.rank == 0:
428 | name2sum = defaultdict(float)
429 | name2count = defaultdict(float)
430 | for n2vc in all_name2valcount:
431 | for name, (val, count) in n2vc.items():
432 | try:
433 | val = float(val)
434 | except ValueError:
435 | if comm.rank == 0:
436 | warnings.warn(
437 | "WARNING: tried to compute mean on non-float {}={}".format(
438 | name, val
439 | )
440 | )
441 | else:
442 | name2sum[name] += val * count
443 | name2count[name] += count
444 | return {name: name2sum[name] / name2count[name] for name in name2sum}
445 | else:
446 | return {}
447 |
448 |
449 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
450 | """
451 | If comm is provided, average all numerical stats across that comm
452 | """
453 | if dir is None:
454 | dir = os.getenv("OPENAI_LOGDIR")
455 | if dir is None:
456 | dir = osp.join(
457 | tempfile.gettempdir(),
458 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
459 | )
460 | assert isinstance(dir, str)
461 | dir = os.path.expanduser(dir)
462 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
463 |
464 | rank = get_rank_without_mpi_import()
465 | if rank > 0:
466 | log_suffix = log_suffix + "-rank%03i" % rank
467 |
468 | if format_strs is None:
469 | if rank == 0:
470 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
471 | else:
472 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
473 | format_strs = filter(None, format_strs)
474 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
475 |
476 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
477 | if output_formats:
478 | log("Logging to %s" % dir)
479 |
480 |
481 | def _configure_default_logger():
482 | configure()
483 | Logger.DEFAULT = Logger.CURRENT
484 |
485 |
486 | def reset():
487 | if Logger.CURRENT is not Logger.DEFAULT:
488 | Logger.CURRENT.close()
489 | Logger.CURRENT = Logger.DEFAULT
490 | log("Reset logger")
491 |
492 |
493 | @contextmanager
494 | def scoped_configure(dir=None, format_strs=None, comm=None):
495 | prevlogger = Logger.CURRENT
496 | configure(dir=dir, format_strs=format_strs, comm=comm)
497 | try:
498 | yield
499 | finally:
500 | Logger.CURRENT.close()
501 | Logger.CURRENT = prevlogger
502 |
--------------------------------------------------------------------------------
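
A minimal usage sketch for the logging API above. The `model.utils.logger` import path and the wandb setup are assumptions for illustration; note that `dumpkvs()` also mirrors metrics to wandb on rank 0, so a wandb run (here disabled) is initialized first:

import wandb

import model.utils.logger as logger

# Disabled wandb run so dumpkvs()'s wandb.log() call is a no-op.
wandb.init(mode="disabled")

# Write human-readable stdout plus a CSV file under ./logs;
# OPENAI_LOGDIR / OPENAI_LOG_FORMAT are honored when these args are omitted.
logger.configure(dir="./logs", format_strs=["stdout", "csv"])

for step in range(3):
    logger.logkv("step", step)
    logger.logkv_mean("loss", 1.0 / (step + 1))  # repeated calls within an iteration are averaged
    logger.dumpkvs()  # flush this iteration's key/value pairs to every configured format

logger.log("training finished", level=logger.INFO)
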
/model/utils/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for various likelihood-based losses. These are ported from the original
3 | Ho et al. diffusion models codebase:
4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
5 | """
6 |
7 | import numpy as np
8 | import torch as th
9 |
10 |
11 | def normal_kl(mean1, logvar1, mean2, logvar2):
12 | """
13 | Compute the KL divergence between two gaussians.
14 |
15 | Shapes are automatically broadcasted, so batches can be compared to
16 | scalars, among other use cases.
17 | """
18 | tensor = None
19 | for obj in (mean1, logvar1, mean2, logvar2):
20 | if isinstance(obj, th.Tensor):
21 | tensor = obj
22 | break
23 | assert tensor is not None, "at least one argument must be a Tensor"
24 |
25 | # Force variances to be Tensors. Broadcasting helps convert scalars to
26 | # Tensors, but it does not work for th.exp().
27 | logvar1, logvar2 = [
28 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
29 | for x in (logvar1, logvar2)
30 | ]
31 |
32 | # print(logvar2.shape)
33 | # temp1 = 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2))
34 | # print(f'const = {temp1.mean()}, coef={(th.exp(-logvar2) * 0.5).mean()},
35 | # mse={((mean1 - mean2) ** 2).mean().item()}')
36 |
37 | return 0.5 * (
38 | -1.0
39 | + logvar2
40 | - logvar1
41 | + th.exp(logvar1 - logvar2)
42 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
43 | )
44 |
45 |
46 | def approx_standard_normal_cdf(x):
47 | """
48 | A fast approximation of the cumulative distribution function of the
49 | standard normal.
50 | """
51 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
52 |
53 |
54 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
55 | """
56 | Compute the log-likelihood of a Gaussian distribution discretizing to a
57 | given image.
58 |
59 | :param x: the target images. It is assumed that this was uint8 values,
60 | rescaled to the range [-1, 1].
61 | :param means: the Gaussian mean Tensor.
62 | :param log_scales: the Gaussian log stddev Tensor.
63 | :return: a tensor like x of log probabilities (in nats).
64 | """
65 | assert x.shape == means.shape == log_scales.shape
66 | centered_x = x - means
67 | inv_stdv = th.exp(-log_scales)
68 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
69 | cdf_plus = approx_standard_normal_cdf(plus_in)
70 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
71 | cdf_min = approx_standard_normal_cdf(min_in)
72 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
73 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
74 | cdf_delta = cdf_plus - cdf_min
75 | log_probs = th.where(
76 | x < -0.999,
77 | log_cdf_plus,
78 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
79 | )
80 | assert log_probs.shape == x.shape
81 | return log_probs
82 |
83 |
84 | def gaussian_density(x, *, means, log_scales):
85 | from torch.distributions import Normal
86 |
87 | normal_dist = Normal(means, log_scales.exp())
88 | logp = normal_dist.log_prob(x)
89 | return logp
90 |
91 |
92 | def discretized_text_log_likelihood(x, *, means, log_scales):
93 | """
94 | Compute the log-likelihood of a Gaussian distribution discretizing to a
95 | given image.
96 |
97 | :param x: the target images. It is assumed that this was uint8 values,
98 | rescaled to the range [-1, 1].
99 | :param means: the Gaussian mean Tensor.
100 | :param log_scales: the Gaussian log stddev Tensor.
101 | :return: a tensor like x of log probabilities (in nats).
102 | """
103 | # print(x.shape, means.shape)
104 | # assert x.shape == means.shape == log_scales.shape
105 | # print(x, means)
106 | centered_x = x - means
107 | inv_stdv = th.exp(-log_scales)
108 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
109 | cdf_plus = approx_standard_normal_cdf(plus_in)
110 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
111 | cdf_min = approx_standard_normal_cdf(min_in)
112 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
113 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
114 | cdf_delta = cdf_plus - cdf_min
115 | log_probs = th.where(
116 | x < -0.999,
117 | log_cdf_plus,
118 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
119 | )
120 | assert log_probs.shape == x.shape
121 | return log_probs
122 |
--------------------------------------------------------------------------------
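
A quick sanity-check sketch for the helpers above (illustrative; the import path is an assumption):

import torch as th

from model.utils.losses import discretized_gaussian_log_likelihood, normal_kl

# The KL divergence between two identical Gaussians is exactly zero.
mean, logvar = th.zeros(4), th.zeros(4)
assert th.allclose(normal_kl(mean, logvar, mean, logvar), th.zeros(4))

# Log-likelihoods of values in [-1, 1] under a discretized standard Gaussian.
x = th.rand(2, 3) * 2 - 1
logp = discretized_gaussian_log_likelihood(
    x, means=th.zeros(2, 3), log_scales=th.zeros(2, 3)
)
assert logp.shape == x.shape  # one log-probability (in nats) per element
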
/model/utils/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def linear(*args, **kwargs):
23 | """
24 | Create a linear module.
25 | """
26 | return nn.Linear(*args, **kwargs)
27 |
28 |
29 | def avg_pool_nd(dims, *args, **kwargs):
30 | """
31 | Create a 1D, 2D, or 3D average pooling module.
32 | """
33 | if dims == 1:
34 | return nn.AvgPool1d(*args, **kwargs)
35 | elif dims == 2:
36 | return nn.AvgPool2d(*args, **kwargs)
37 | elif dims == 3:
38 | return nn.AvgPool3d(*args, **kwargs)
39 | raise ValueError(f"unsupported dimensions: {dims}")
40 |
41 |
42 | def update_ema(target_params, source_params, rate=0.99):
43 | """
44 | Update target parameters to be closer to those of source parameters using
45 | an exponential moving average.
46 |
47 | :param target_params: the target parameter sequence.
48 | :param source_params: the source parameter sequence.
49 | :param rate: the EMA rate (closer to 1 means slower).
50 | """
51 | for targ, src in zip(target_params, source_params):
52 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
53 |
54 |
55 | def zero_module(module):
56 | """
57 | Zero out the parameters of a module and return it.
58 | """
59 | for p in module.parameters():
60 | p.detach().zero_()
61 | return module
62 |
63 |
64 | def scale_module(module, scale):
65 | """
66 | Scale the parameters of a module and return it.
67 | """
68 | for p in module.parameters():
69 | p.detach().mul_(scale)
70 | return module
71 |
72 |
73 | def mean_flat(tensor):
74 | """
75 | Take the mean over all non-batch dimensions.
76 | """
77 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
78 |
79 |
80 | def normalization(channels):
81 | """
82 | Make a standard normalization layer.
83 |
84 | :param channels: number of input channels.
85 | :return: an nn.Module for normalization.
86 | """
87 | return GroupNorm32(32, channels)
88 |
89 |
90 | def timestep_embedding(timesteps, dim, max_period=10000):
91 | """
92 | Create sinusoidal timestep embeddings.
93 |
94 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
95 | These may be fractional.
96 | :param dim: the dimension of the output.
97 | :param max_period: controls the minimum frequency of the embeddings.
98 | :return: an [N x dim] Tensor of positional embeddings.
99 | """
100 | half = dim // 2
101 | freqs = th.exp(
102 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
103 | ).to(device=timesteps.device)
104 | args = timesteps[:, None].float() * freqs[None]
105 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
106 | if dim % 2:
107 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
108 | return embedding
109 |
--------------------------------------------------------------------------------
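
A small sketch exercising the utilities above (the import path is an assumption):

import torch as th

from model.utils.nn import mean_flat, timestep_embedding, update_ema

# Sinusoidal embeddings: one row of `dim` features per (possibly fractional) timestep.
t = th.tensor([0.0, 10.0, 100.0])
assert timestep_embedding(t, dim=128).shape == (3, 128)

# EMA update: target <- rate * target + (1 - rate) * source, in place.
target, source = [th.zeros(5)], [th.ones(5)]
update_ema(target, source, rate=0.9)
assert th.allclose(target[0], th.full((5,), 0.1))

# mean_flat reduces every non-batch dimension.
assert mean_flat(th.ones(2, 3, 4)).shape == (2,)
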
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.30.1
2 | beautifulsoup4==4.12.3
3 | black==24.4.2
4 | flake8==7.0.0
5 | huggingface-hub==0.27.1
6 | isort==5.13.2
7 | matplotlib==3.9.0
8 | numpy==1.26.4
9 | pandas==2.2.3
10 | peft==0.11.1
11 | regex==2024.11.6
12 | requests==2.32.2
13 | sentencepiece==0.2.0
14 | tokenizers==0.20.2
15 | torch==2.5.1
16 | tqdm==4.66.4
17 | transformers @ git+https://github.com/huggingface/transformers@bdb9106f247fca48a71eb384be25dbbd29b065a8
18 | triton==3.1.0
19 | urllib3==2.2.2
20 | wandb==0.17.0
21 | Wikidata==0.7.0
22 | wikipedia==1.4.0
23 | xformers==0.0.26.post1
24 |
--------------------------------------------------------------------------------
/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
4 | class Trainer(ABC):
5 | @abstractmethod
6 | def train(
7 | self,
8 | save_path,
9 | model,
10 | train_dataloader,
11 | val_dataloader,
12 | optimizer,
13 | scheduler,
14 | epochs,
15 | **kwargs
16 | ):
17 | # training process
18 | raise NotImplementedError
19 |
20 | @abstractmethod
21 | def evaluate(self, model, test_dataloader, tokenizer, **kwargs):
22 | # testing process
23 | raise NotImplementedError
24 |
--------------------------------------------------------------------------------
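
A minimal sketch of the contract this abstract base class imposes (the subclass below is hypothetical, for illustration only):

from trainer import Trainer


class NoopTrainer(Trainer):
    def train(self, save_path, model, train_dataloader, val_dataloader,
              optimizer, scheduler, epochs, **kwargs):
        # A real trainer would run its optimization loop here.
        for _ in range(epochs):
            pass

    def evaluate(self, model, test_dataloader, tokenizer, **kwargs):
        # A real trainer would return metrics here.
        return {}


NoopTrainer()  # instantiable only because both abstract methods are implemented
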
/trainer/trainer_musee.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import wandb
7 | from torch import nn
8 | from torch.nn.functional import softmax
9 | from transformers import T5Config, T5ForConditionalGeneration
10 |
11 | from . import Trainer
12 |
13 |
14 | def decode_logits(logits, tokenizer):
15 | probs = softmax(logits, dim=-1)
16 |
17 | # Get the most likely token IDs
18 | predicted_ids = torch.argmax(probs, dim=-1)
19 | print("predicted_ids:", predicted_ids)
20 |
21 | # Decode the token IDs to tokens
22 | decoded_tokens = []
23 | for i in range(predicted_ids.shape[0]):
24 | decoded_sequence = tokenizer.decode(predicted_ids[i], skip_special_tokens=True)
25 | decoded_tokens.append(decoded_sequence.split())
26 |
27 | return decoded_tokens
28 |
29 |
30 | def replace_with_closest_embedding(
31 | predict_pk_ids, added_ent_type_tokens, added_pk_tokens, model, device
32 | ):
33 | # Getting the embeddings from the model and moving them to the specified device
34 | embeddings = model.t5_model.shared.weight.to(device)
35 |
36 | def find_closest_token(target_token, embeddings, allowed_tokens, device):
37 | target_embedding = (
38 | embeddings[target_token].unsqueeze(0).to(device)
39 | ) # Get the embedding of the target token
40 | allowed_embeddings = embeddings[allowed_tokens].to(
41 | device
42 | ) # Get embeddings of allowed tokens
43 | distances = torch.norm(target_embedding - allowed_embeddings, dim=1)
44 | closest_idx = distances.argmin()
45 | closest_token = allowed_tokens[closest_idx]
46 |
47 | return closest_token
48 |
49 | for idx, sequence in enumerate(predict_pk_ids):
50 | for token_idx, token in enumerate(sequence):
51 | if token_idx == 0:
52 | allowed_tokens = added_ent_type_tokens
53 | else:
54 | allowed_tokens = added_pk_tokens + [1] # Adding end token
55 |
56 | closest_token = find_closest_token(
57 | token, embeddings, allowed_tokens, device
58 | )
59 | predict_pk_ids[idx, token_idx] = closest_token
60 |
61 | return predict_pk_ids
62 |
63 |
64 | class Predictor_E_Pk_Pv(nn.Module):
65 | def __init__(
66 | self,
67 | pretrained_model_name,
68 | max_seq_length,
69 | max_num_entity,
70 | tokenizer,
71 | ):
72 | super(Predictor_E_Pk_Pv, self).__init__()
73 |
74 | self.tokenizer = tokenizer
75 | self.max_seq_length = max_seq_length
76 | self.max_num_entity = max_num_entity
77 |
78 | self.t5_model = T5ForConditionalGeneration.from_pretrained(
79 | pretrained_model_name
80 | )
81 |
82 | def _prompt_decoder_return_logits(
83 | self,
84 | prompt_tokenized,
85 | input_ids,
86 | encoder_outputs,
87 | attention_mask,
88 | labels,
89 | added_ent_type_tokens=None,
90 | added_pk_tokens=None,
91 | ):
92 | start_token_id = self.tokenizer.pad_token_id
93 | end_token_id = self.tokenizer.eos_token_id
94 | start_tokens = torch.full(
95 | (labels.size(0), 1), start_token_id, dtype=torch.long, device=labels.device
96 | )
97 |
98 | combined_decoder_input_ids = torch.cat(
99 | [start_tokens, prompt_tokenized, labels], dim=-1
100 | )
101 | # Move all non-zero of each row to the beginning, except the first position be 0
102 | temp_result = torch.zeros_like(combined_decoder_input_ids)
103 | for i, row in enumerate(combined_decoder_input_ids):
104 | non_zeros = row[row != 0] # Extract non-zero elements
105 | temp_result[i, :1] = 0 # Keep the first element as 0
106 | temp_result[i, 1 : 1 + len(non_zeros)] = (
107 | non_zeros # Place non-zero elements
108 | )
109 | combined_decoder_input_ids = temp_result
110 |
111 | attention_mask_decoder = (
112 | combined_decoder_input_ids != self.tokenizer.pad_token_id
113 | ).long()
114 | attention_mask_decoder[:, 0] = 1
115 |
116 | # print("attention_mask_decoder:", attention_mask_decoder.sum(1))
117 | # print("combined_decoder_input_ids:", combined_decoder_input_ids)
118 | # print("attention_mask_decoder:", attention_mask_decoder)
119 |
120 | # if encoder_outputs is not None:
121 | logits = self.t5_model(
122 | # input_ids=input_ids,
123 | encoder_outputs=(
124 | encoder_outputs,
125 | ), # convert the encoder_outputs back to a tuple
126 | attention_mask=attention_mask,
127 | decoder_input_ids=combined_decoder_input_ids,
128 | decoder_attention_mask=attention_mask_decoder,
129 | ).logits[
130 | :, :-1
131 | ] # (batch_size, prompt_len + tgt_seq_len, vocab)
132 | # print("logits encoder_outputs:", logits.sum())
133 |
134 | # else:
135 | # logits = self.t5_model(
136 | # input_ids=input_ids,
137 | # # encoder_outputs=(encoder_outputs,), # convert the encoder_outputs back to a tuple
138 | # attention_mask=attention_mask,
139 | # decoder_input_ids=combined_decoder_input_ids,
140 | # decoder_attention_mask=attention_mask_decoder,
141 | # ).logits[:, :-1] # (batch_size, prompt_len + tgt_seq_len, vocab)
142 | # print("logits input_id:", logits.sum())
143 |
144 | """If output is: 123000 || abc000 """
145 | # logits = logits[:, -labels.size(1):] # (batch_size, tgt_seq_len, vocab)
146 |
147 | """If output is: 123abc || 000000 """
148 | # print("length:", labels.size(1))
149 | # Compute the start index for each row based on the number of non-zero values
150 | start_indices = (prompt_tokenized != 0).long().sum(dim=1)
151 | # print("start_indices:", start_indices)
152 | indices_range = torch.arange(labels.size(1)).unsqueeze(0).to(
153 | labels.device
154 | ) + start_indices.unsqueeze(1)
155 | # print("indices_range:", indices_range)
156 | batch_indices = torch.arange(logits.size(0)).unsqueeze(1).to(labels.device)
157 | # print("batch_indices:", batch_indices)
158 |
159 | logits = logits[
160 | batch_indices, indices_range
161 | ] # (batch_size, tgt_seq_len, vocab)
162 |
163 | if added_ent_type_tokens is not None and added_pk_tokens is not None:
164 | # Apply constraints: candidates can only be {ent_type, pk, eos}
165 | large_negative = -1e9
166 | mask_token = torch.full_like(logits, large_negative).to(labels.device)
167 | # Apply constraints to all positions
168 | for token_set in [added_ent_type_tokens, added_pk_tokens, [end_token_id]]:
169 | mask_token[:, :, token_set] = 0
170 |
171 | logits += mask_token
172 |
173 | return logits
174 |
175 | def extract_entity_names(self, label_string, tokenizer):
176 | sep_token = "<extra_id_1>"  # T5 sentinel (id 32098) used to separate entity names
177 | sep_token_id = tokenizer.convert_tokens_to_ids(sep_token)
178 | token_ids = tokenizer.encode(label_string, add_special_tokens=False)
179 |
180 | entities = []
181 | current_entity = []
182 | for token_id in token_ids:
183 | if token_id == sep_token_id:
184 | if current_entity:
185 | entity_name = tokenizer.decode(
186 | current_entity, skip_special_tokens=True
187 | )
188 | entities.append(entity_name)
189 | current_entity = []
190 | else:
191 | current_entity.append(token_id)
192 |
193 | return entities
194 |
195 | def inference_generate_ids(
196 | self,
197 | prompt_tokenized,
198 | input_ids,
199 | attention_mask,
200 | max_length,
201 | added_ent_type_tokens=None,
202 | added_pk_tokens=None,
203 | ):
204 | batch_size = input_ids.size(0)
205 | # Initialize a tensor of ones
206 | all_predict_ids = torch.ones(
207 | batch_size, max_length, dtype=torch.long, device=input_ids.device
208 | )
209 |
210 | # Determine the suppress_tokens based on added_ent_type_tokens and added_pk_tokens
211 | suppress_tokens = None
212 | if added_ent_type_tokens is not None and added_pk_tokens is not None:
213 | allowed_tokens = set(
214 | added_ent_type_tokens + added_pk_tokens + [1]
215 | ) # Including end token
216 | # print("allowed_tokens:", allowed_tokens)
217 | all_tokens = set(range(self.t5_model.config.vocab_size))
218 | suppress_tokens = list(all_tokens - allowed_tokens)
219 |
220 | for idx in range(batch_size):
221 | single_prompt = prompt_tokenized[idx, :].unsqueeze(0)
222 | single_cleaned_prompt = single_prompt[single_prompt != 0].unsqueeze(0)
223 |
224 | single_input_ids = input_ids[idx, :].unsqueeze(0)
225 | single_attention_mask = attention_mask[idx, :].unsqueeze(0)
226 |
227 | forced_decoder_ids = [
228 | [index + 1, element.item()]
229 | for index, element in enumerate(single_prompt[0])
230 | if element.item() != 0
231 | ]
232 |
233 | # Generate predictions with suppress_tokens if applicable
234 | generate_args = {
235 | "input_ids": single_input_ids,
236 | "attention_mask": single_attention_mask,
237 | "forced_decoder_ids": forced_decoder_ids,
238 | "max_length": max_length,
239 | }
240 | if suppress_tokens is not None:
241 | generate_args["suppress_tokens"] = suppress_tokens
242 |
243 | predict_ids = self.t5_model.generate(**generate_args)
244 | # if added_ent_type_tokens is not None and added_pk_tokens is not None:
245 | # print("single_prompt:", single_prompt)
246 | # print("predict_ids:", predict_ids)
247 |
248 | prompt_size = len(forced_decoder_ids)
249 | trimmed_predict_ids = predict_ids[
250 | :, prompt_size + 1 :
251 | ] # +1 due to the first generated token always being 0
252 |
253 | output_length = trimmed_predict_ids.size(1)
254 | all_predict_ids[idx, :output_length] = trimmed_predict_ids.squeeze(0)
255 |
256 | return all_predict_ids
257 |
258 | def forward(
259 | self,
260 | input_ids, # (b, seq_len)
261 | labels_ent_name, # (b, max_num_Entity * 6)
262 | real_labels_ent_name, # (b, max_num_Entity * 6)
263 | labels_pk, # (b, max_num_Entity, num_all_pks+2)
264 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len)
265 | attention_mask,
266 | attention_mask_ent_name,
267 | attention_mask_pk,
268 | attention_mask_pv,
269 | max_len_pv,
270 | device,
271 | added_ent_type_tokens,
272 | added_pk_tokens,
273 | mode="train",
274 | ):
275 | # Encode the input sequence once
276 | encoder_outputs = self.t5_model.encoder(
277 | input_ids=input_ids,
278 | attention_mask=attention_mask,
279 | output_hidden_states=False,
280 | return_dict=False,
281 | )[
282 | 0
283 | ] # it is a tuple (xxx,), we use [0] to choose xxx
284 |
285 | """Step1: logits_ent"""
286 | prompt_step1 = "predict entities"
287 | prompt_tokenized = (
288 | self.tokenizer.encode(
289 | prompt_step1, add_special_tokens=False, return_tensors="pt"
290 | )
291 | .repeat(input_ids.size(0), 1)
292 | .to(input_ids.device)
293 | ) # (b, x)
294 |
295 | if mode == "train":
296 | logits_ent = self._prompt_decoder_return_logits(
297 | prompt_tokenized,
298 | input_ids,
299 | encoder_outputs,
300 | attention_mask,
301 | labels_ent_name,
302 | ) # (b, max_num_Entity * 6, vocab)
303 | elif mode == "test":
304 | predict_ent_ids = self.inference_generate_ids(
305 | prompt_tokenized, input_ids, attention_mask, self.max_num_entity * 6
306 | )
307 |
308 | """Step2: logits_pk"""
309 | max_prompt_length = 20
310 | batch_size, seq_len = input_ids.shape
311 |
312 | input_ids_ent_batch = []
313 | enc_outputs_ent_batch = []
314 | attention_mask_ent_batch = []
315 |
316 | # Initialize a list for storing padded decoder inputs
317 | prompt_padded_batch_step2 = []
318 |
319 | for sample_idx in range(batch_size):
320 | if mode == "train":
321 | # Extract the entity names from the ground truth labels
322 | entity_names = self.extract_entity_names(
323 | real_labels_ent_name[sample_idx], self.tokenizer
324 | )
325 | elif mode == "test":
326 | # Decode the predicted entity names from Step 1
327 | predicted_ent = [
328 | self.tokenizer.decode(ids, skip_special_tokens=False)
329 | for ids in predict_ent_ids
330 | ]
331 | entity_names = self.extract_entity_names(
332 | predicted_ent[sample_idx], self.tokenizer
333 | )
334 | if entity_names == []:
335 | entity_names = ["Fail to predict"]
336 |
337 | for entity_name in entity_names:
338 | # Format the input for the T5 decoder
339 | prompt_sample = f"predict type and properties {entity_name}"
340 | # print("prompt_sample step2:", prompt_sample)
341 | prompt_sample_tokenized = self.tokenizer.encode(
342 | prompt_sample, add_special_tokens=False, return_tensors="pt"
343 | ).to(input_ids.device)
344 |
345 | # Pad the tokenized input to the max_decoder_length
346 | prompt_padded_sample = torch.nn.functional.pad(
347 | prompt_sample_tokenized,
348 | # (max_prompt_length - prompt_sample_tokenized.shape[1], 0), # Add padding at the beginning
349 | (0, max_prompt_length - prompt_sample_tokenized.shape[1]),
350 | value=self.tokenizer.pad_token_id,
351 | )
352 |
353 | # Add to the list
354 | prompt_padded_batch_step2.append(prompt_padded_sample)
355 |
356 | # Repeat input_ids and attention_mask for each entity
357 | repeated_input_ids = input_ids[sample_idx].unsqueeze(0)
358 | repeated_enc_outputs = encoder_outputs[sample_idx].unsqueeze(0)
359 | repeated_attention_mask = attention_mask[sample_idx].unsqueeze(0)
360 | input_ids_ent_batch.append(repeated_input_ids)
361 | enc_outputs_ent_batch.append(repeated_enc_outputs)
362 | attention_mask_ent_batch.append(repeated_attention_mask)
363 |
364 | # Concatenate the repeated input_ids and attention masks to form a batch
365 | input_ids_ent_batch = torch.cat(
366 | input_ids_ent_batch, dim=0
367 | ) # (num_ent_in_batch, 512)
368 | enc_outputs_ent_batch = torch.cat(
369 | enc_outputs_ent_batch, dim=0
370 | ) # (num_ent_in_batch, 512, 512)
371 | attention_mask_ent_batch = torch.cat(
372 | attention_mask_ent_batch, dim=0
373 | ) # (num_ent_in_batch, 512)
374 | prompt_padded_batch_step2 = torch.cat(
375 | prompt_padded_batch_step2, dim=0
376 | ) # (num_ent_in_batch, max_prompt_length)
377 |
378 | if mode == "train":
379 | # remove all 0 tensors, since those are only for padding
380 | labels_pk_flatten = labels_pk[
381 | (labels_pk != 0).any(dim=2)
382 | ] # (num_ent_in_batch, tgt_seq_len)
383 | logits_pk = self._prompt_decoder_return_logits(
384 | prompt_padded_batch_step2,
385 | input_ids_ent_batch,
386 | enc_outputs_ent_batch,
387 | attention_mask_ent_batch,
388 | labels_pk_flatten,
389 | added_ent_type_tokens=added_ent_type_tokens,
390 | added_pk_tokens=list(set(added_pk_tokens) - set(["pk_entity_name"])),
391 | ) # (num_ent_in_batch, tgt_seq_len, vocab)
392 | elif mode == "test":
393 | predict_pk_ids = self.inference_generate_ids(
394 | prompt_padded_batch_step2,
395 | input_ids_ent_batch,
396 | attention_mask_ent_batch,
397 | max_prompt_length + 10,
398 | added_ent_type_tokens=added_ent_type_tokens,
399 | added_pk_tokens=list(set(added_pk_tokens) - set(["pk_entity_name"])),
400 | )
401 | # print("predict_pk_ids:", predict_pk_ids)
402 |
403 | """Step3: logits_pv"""
404 | max_prompt_length = 30 # Set a suitable max length for the decoder prompts
405 |
406 | input_ids_pv_batch = []
407 | enc_outputs_pv_batch = []
408 | attention_mask_pv_batch = []
409 | prompt_padded_batch_step3 = []
410 |
411 | for sample_idx in range(batch_size):
412 | if mode == "train":
413 | # Extract the entity names from the ground truth labels
414 | entity_names = self.extract_entity_names(
415 | real_labels_ent_name[sample_idx], self.tokenizer
416 | )
417 | elif mode == "test":
418 | # Decode the predicted entity names from Step 1
419 | predicted_ent = [
420 | self.tokenizer.decode(ids, skip_special_tokens=False)
421 | for ids in predict_ent_ids
422 | ]
423 | entity_names = self.extract_entity_names(
424 | predicted_ent[sample_idx], self.tokenizer
425 | )
426 | labels_pk = predict_pk_ids.unsqueeze(0)
427 |
428 | for entity_idx, entity_name in enumerate(entity_names):
429 | # Get the entity type
430 | entity_type_id = labels_pk[sample_idx, entity_idx, 0]
431 | entity_type_token = self.tokenizer.decode([entity_type_id])
432 |
433 | # Iterate through each property key for the entity, starting from column index 1 in labels_pk
434 | for pk_idx in range(1, labels_pk.size(2) - 1):
435 | pk_id = labels_pk[sample_idx, entity_idx, pk_idx]
436 | if pk_id in [
437 | self.tokenizer.pad_token_id,
438 | self.tokenizer.eos_token_id,
439 | ]:
440 | continue
441 |
442 | # Retrieve property key token
443 | pk_token = self.tokenizer.decode([pk_id])
444 |
445 | # Format the input for the T5 decoder
446 | prompt_sample = f"Predict property value for {entity_name} {entity_type_token} {pk_token}"
447 | # print("prompt_sample step3:", prompt_sample)
448 | prompt_sample_tokenized = self.tokenizer.encode(
449 | prompt_sample, add_special_tokens=False, return_tensors="pt"
450 | ).to(input_ids.device)
451 |
452 | # Pad the tokenized input to the max_prompt_length
453 | prompt_padded_sample = torch.nn.functional.pad(
454 | prompt_sample_tokenized,
455 | # (max_prompt_length - prompt_sample_tokenized.shape[1], 0),
456 | (0, max_prompt_length - prompt_sample_tokenized.shape[1]),
457 | value=self.tokenizer.pad_token_id,
458 | )
459 |
460 | # Add to the list
461 | prompt_padded_batch_step3.append(prompt_padded_sample)
462 |
463 | # Repeat input_ids and attention_mask for each property key
464 | repeated_input_ids = input_ids[sample_idx].unsqueeze(0)
465 | repeated_enc_outputs = encoder_outputs[sample_idx].unsqueeze(0)
466 | repeated_attention_mask = attention_mask[sample_idx].unsqueeze(0)
467 | input_ids_pv_batch.append(repeated_input_ids)
468 | enc_outputs_pv_batch.append(repeated_enc_outputs)
469 | attention_mask_pv_batch.append(repeated_attention_mask)
470 |
471 | # Concatenate the repeated input_ids to form a batch
472 | if (
473 | input_ids_pv_batch != []
474 | ): # Edge case for inference, sometimes the model predicts nothing for step2
475 | input_ids_pv_batch = torch.cat(
476 | input_ids_pv_batch, dim=0
477 | ) # (num_pair_batch, seq_len)
478 | enc_outputs_pv_batch = torch.cat(
479 | enc_outputs_pv_batch, dim=0
480 | ) # (num_pair_batch, seq_len)
481 | attention_mask_pv_batch = torch.cat(
482 | attention_mask_pv_batch, dim=0
483 | ) # (num_pair_batch, seq_len)
484 | prompt_padded_batch_step3 = torch.cat(
485 | prompt_padded_batch_step3, dim=0
486 | ) # (num_pair_batch, max_prompt_length)
487 |
488 | # if enc_outputs_pv_batch != []: # Edge case for inference, sometimes the model predicts nothing for step2
489 | #
490 | # attention_mask_pv_batch = torch.cat(attention_mask_pv_batch, dim=0) # (num_pair_batch, seq_len)
491 | # prompt_padded_batch_step3 = torch.cat(prompt_padded_batch_step3,
492 | # dim=0) # (num_pair_batch, max_prompt_length)
493 |
494 | if mode == "train":
495 | if input_ids_pv_batch != []:
496 | # remove all 0 tensors, since those are only for padding
497 | labels_pv_flatten = labels_pv[
498 | (labels_pv != self.tokenizer.pad_token_id).any(dim=3)
499 | ] # Flatten the labels_pv tensor
500 | # Predict the property values
501 | logits_pv = self._prompt_decoder_return_logits(
502 | prompt_padded_batch_step3,
503 | input_ids_pv_batch,
504 | enc_outputs_pv_batch,
505 | attention_mask_pv_batch,
506 | labels_pv_flatten,
507 | ) # (num_pair_batch, max_prop_len, vocab)
508 | else: # There is no other pk in addition to pk_ent_name. So we do not need to predict anything
509 | logits_pv = None
510 | elif mode == "test":
511 | if (
512 | input_ids_pv_batch != []
513 | ): # Edge case for inference, sometimes the model predicts nothing for step2
514 | # Generate the property value predictions
515 | predict_pv_ids = self.inference_generate_ids(
516 | prompt_padded_batch_step3,
517 | input_ids_pv_batch,
518 | attention_mask_pv_batch,
519 | max_length=max_prompt_length + 20,
520 | )
521 | else:
522 | predict_pv_ids = torch.tensor([[0]])
523 |
524 | if mode == "train":
525 | return logits_ent, logits_pk, logits_pv
526 | elif mode == "test":
527 | return predict_ent_ids, predict_pk_ids, predict_pv_ids
528 |
529 |
530 | class Trainer_E_Pk_Pv(Trainer):
531 | @staticmethod
532 | def calculate_loss(
533 | logits_ent,
534 | logits_pk,
535 | logits_pv,
536 | labels_ent_name, # (b, max_num_Entity * 6)
537 | labels_pk, # (b, max_num_Entity, num_all_pks+2)
538 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len)
539 | attention_mask_ent_name,
540 | attention_mask_pk,
541 | attention_mask_pv,
542 | added_ent_type_tokens,
543 | loss_mode,
544 | device,
545 | ):
546 | def calculate_loss_for_step(
547 | logits,
548 | labels,
549 | attention_mask,
550 | criterion,
551 | loss_mode,
552 | added_ent_type_tokens=None,
553 | end_token_id=1,
554 | sep_token_id=32098,
555 | ):
556 | # Flatten logits and labels
557 | logits_flatten = logits.reshape(-1, logits.size(-1))
558 | labels_flatten = labels.view(-1)
559 |
560 | # Compute token-wise loss
561 | loss_tokenwise = criterion(logits_flatten, labels_flatten)
562 |
563 | # Apply attention mask
564 | loss_masked = loss_tokenwise.view_as(labels) * attention_mask
565 |
566 | # Initialize weights with ones
567 | weights = torch.ones_like(labels, dtype=torch.float, device=labels.device)
568 |
569 | # Increase weight for sep_token_id
570 | weights[labels == sep_token_id] = 2
571 | # weights[labels == end_token_id] = 2
572 |
573 | # # Increase weights for tokens in added_ent_type_tokens, if specified
574 | # if added_ent_type_tokens is not None:
575 | # for token_id in added_ent_type_tokens:
576 | # weights[labels == token_id] = 3 # Or another weight factor as desired
577 |
578 | # Apply weights to the loss
579 | loss_weighted = loss_masked * weights
580 |
581 | # Compute average loss
582 | valid_tokens_mask = attention_mask == 1
583 | if loss_mode == "mean":
584 | loss = (
585 | loss_weighted.sum(dim=-1) / valid_tokens_mask.sum(dim=-1)
586 | ).mean()
587 | elif loss_mode == "sum":
588 | loss = (loss_weighted.sum(dim=-1) / valid_tokens_mask.sum(dim=-1)).sum()
589 |
590 | return loss
591 |
592 | criterion = torch.nn.CrossEntropyLoss(
593 | reduction="none"
594 | ) # Initialize without parameters
595 |
596 | """Step1: loss_ent"""
597 | # labels_ent_tokenized # (b, max_num_Entity * 3)
598 | loss_ent = calculate_loss_for_step(
599 | logits_ent, labels_ent_name, attention_mask_ent_name, criterion, loss_mode
600 | )
601 |
602 | """Step2: loss_pk"""
603 | labels_pk_flatten = labels_pk[
604 | (labels_pk != 0).any(dim=2)
605 | ] # (num_ent_batch, tgt_seq_len)
606 | attention_mask_pk_flatten = attention_mask_pk[
607 | (attention_mask_pk != 0).any(dim=2)
608 | ] # (num_ent_batch, tgt_seq_len)
609 | loss_pk = calculate_loss_for_step(
610 | logits_pk,
611 | labels_pk_flatten,
612 | attention_mask_pk_flatten,
613 | criterion,
614 | loss_mode,
615 | added_ent_type_tokens,
616 | )
617 |
618 | """Step3: loss_pv"""
619 | if logits_pv is not None:
620 | labels_pv_flatten = labels_pv[
621 | (labels_pv != 0).any(dim=3)
622 | ] # (num_pair_batch, pv_seq_len)
623 | attention_mask_pv_flatten = attention_mask_pv[
624 | (attention_mask_pv != 0).any(dim=3)
625 | ] # (num_pair_batch, pv_seq_len)
626 | loss_pv = calculate_loss_for_step(
627 | logits_pv,
628 | labels_pv_flatten,
629 | attention_mask_pv_flatten,
630 | criterion,
631 | loss_mode,
632 | )
633 | else:
634 | loss_pv = torch.tensor(0).to(loss_pk.device)
635 |
636 | return loss_ent, loss_pk, loss_pv
637 |
638 | @staticmethod
639 | def compute_batch_loss(
640 | batch,
641 | model,
642 | added_ent_type_tokens,
643 | added_pk_tokens,
644 | loss_mode,
645 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
646 | ):
647 | input_ids = batch["input_ids"].to(device) # (b, seq_len)
648 | labels_ent_name = batch["labels_ent_name"].to(device) # (b, max_num_Entity * 6)
649 | real_labels_ent_name = batch["real_labels_ent_name"] # (b, max_num_Entity * 6)
650 | labels_pk = batch["labels_pk"].to(device) # (b, max_num_Entity, num_all_pks+2)
651 | labels_pv = batch["labels_pv"].to(
652 | device
653 | ) # (b, max_num_Entity, num_all_pks, max_prop_len)
654 | attention_mask = batch["attention_mask"].to(device)
655 | attention_mask_ent_name = batch["attention_mask_ent_name"].to(device)
656 | attention_mask_pk = batch["attention_mask_pk"].to(device)
657 | attention_mask_pv = batch["attention_mask_pv"].to(device)
658 | max_len_pv = attention_mask_pv.shape[-1]
659 |
660 | # print("labels_ent_name:", labels_ent_name)
661 | # print("labels_pk:", labels_pk)
662 | # print("labels_pv:", labels_pv)
663 |
664 | logits_ent, logits_pk, logits_pv = model(
665 | input_ids, # (b, seq_len)
666 | labels_ent_name, # (b, max_num_Entity * 6)
667 | real_labels_ent_name, # (b, max_num_Entity * 6)
668 | labels_pk, # (b, max_num_Entity, num_all_pks+2)
669 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len)
670 | attention_mask,
671 | attention_mask_ent_name,
672 | attention_mask_pk,
673 | attention_mask_pv,
674 | max_len_pv,
675 | device,
676 | added_ent_type_tokens,
677 | added_pk_tokens,
678 | )
679 | (
680 | loss_ent,
681 | loss_pk,
682 | loss_pv,
683 | ) = Trainer_E_Pk_Pv.calculate_loss(
684 | logits_ent,
685 | logits_pk,
686 | logits_pv,
687 | labels_ent_name, # (b, max_num_Entity * 6)
688 | labels_pk, # (b, max_num_Entity, num_all_pks+2)
689 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len)
690 | attention_mask_ent_name,
691 | attention_mask_pk,
692 | attention_mask_pv,
693 | added_ent_type_tokens,
694 | loss_mode,
695 | device,
696 | )
697 | return loss_ent, loss_pk, loss_pv
698 |
699 | @staticmethod
700 | def evaluate_full_dataloader(
701 | dataloader,
702 | model,
703 | added_ent_type_tokens,
704 | added_pk_tokens,
705 | loss_mode,
706 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
707 | ):
708 | model.eval()
709 |
710 | total_loss_ent, total_loss_pk, total_loss_pv, total_loss = 0, 0, 0, 0
711 | with torch.no_grad():
712 | for batch in dataloader:
713 | loss_ent, loss_pk, loss_pv = Trainer_E_Pk_Pv.compute_batch_loss(
714 | batch,
715 | model,
716 | added_ent_type_tokens,
717 | added_pk_tokens,
718 | loss_mode,
719 | device,
720 | )
721 | total_loss_ent += loss_ent.item()
722 | total_loss_pk += loss_pk.item()
723 | total_loss_pv += loss_pv.item()
724 | total_loss += loss_ent.item() + loss_pk.item() + loss_pv.item()
725 | return (
726 | total_loss_ent / len(dataloader),
727 | total_loss_pk / len(dataloader),
728 | total_loss_pv / len(dataloader),
729 | total_loss / len(dataloader),
730 | )
731 |
732 | def train(
733 | self,
734 | save_path,
735 | model,
736 | train_dataloader,
737 | val_dataloader,
738 | optimizer,
739 | scheduler,
740 | epochs,
741 | **kwargs,
742 | ):
743 | device = kwargs.get(
744 | "device", torch.device("cuda" if torch.cuda.is_available() else "cpu")
745 | )
746 | log_wandb = kwargs.get("log_wandb", False)
747 | use_lora = kwargs.get("use_lora", True)
748 | alpha = kwargs.get("alpha", 0.5)
749 | added_ent_type_tokens = kwargs.get("added_ent_type_tokens", None)
750 | added_pk_tokens = kwargs.get("added_pk_tokens", None)
751 | loss_mode = kwargs.get("loss_mode", None)
752 | report = kwargs.get("reporter", None)
753 | # # Start wandb
754 | # if log_wandb:
755 | # run = wandb.init(project="your-project-name", entity="your-entity-name")
756 |
757 | # Initialize variables for early stopping
758 | no_improve_step, no_improve_epochs = 0, 0
759 | min_train_loss, min_val_loss = float("inf"), float("inf")
760 |
761 | # Compute and log the initial loss before training
762 | print("Monitor Epoch loss...")
763 | (
764 | avg_train_loss_ent,
765 | avg_train_loss_pk,
766 | avg_train_loss_pv,
767 | avg_train_loss,
768 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader(
769 | train_dataloader,
770 | model,
771 | added_ent_type_tokens,
772 | added_pk_tokens,
773 | loss_mode,
774 | device=device,
775 | )
776 | (
777 | avg_val_loss_ent,
778 | avg_val_loss_pk,
779 | avg_val_loss_pv,
780 | avg_val_loss,
781 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader(
782 | val_dataloader,
783 | model,
784 | added_ent_type_tokens,
785 | added_pk_tokens,
786 | loss_mode,
787 | device=device,
788 | )
789 |
790 | # report(epoch=0, validation_loss=avg_val_loss)
791 | print(f"Epoch: {0}")
792 | print(
793 | f"Train: loss_ent: {avg_train_loss_ent}, loss_pk: {avg_train_loss_pk}, loss_pv: {avg_train_loss_pv}. loss: {avg_train_loss}"
794 | )
795 | print(
796 | f"Val: loss_ent: {avg_val_loss_ent}, loss_pk: {avg_val_loss_pk}, loss_pv: {avg_val_loss_pv}. loss: {avg_val_loss}"
797 | )
798 |
799 | if log_wandb:
800 | wandb.log({"Epoch": 0})
801 | wandb.log(
802 | {
803 | "train_loss_ent": avg_train_loss_ent,
804 | "train_loss_pk": avg_train_loss_pk,
805 | "train_loss_pv": avg_train_loss_pv,
806 | "train_loss": avg_train_loss,
807 | }
808 | )
809 | wandb.log(
810 | {
811 | "val_loss_ent": avg_val_loss_ent,
812 | "val_loss_pk": avg_val_loss_pk,
813 | "val_loss_pv": avg_val_loss_pv,
814 | "val_loss": avg_val_loss,
815 | }
816 | )
817 |
818 | # num_step = 0
819 | for epoch in range(epochs):
820 | # if use_lora:
821 | # print(f"t5_model.shared.original_module", model.t5_model.shared.original_module.weight.sum())
822 | # print(f"t5_model.shared.modules_to_save", model.t5_model.shared.modules_to_save['default'].weight.sum())
823 | start_time = time.time()
824 | model.train()
825 | for batch in train_dataloader:
826 | optimizer.zero_grad()
827 | (
828 | loss_ent,
829 | loss_pk,
830 | loss_pv,
831 | ) = Trainer_E_Pk_Pv.compute_batch_loss(
832 | batch, model, added_ent_type_tokens, added_pk_tokens, loss_mode
833 | )
834 |
835 | # batch_loss = loss_ent * alpha + loss_pk * (1 - alpha) / 2 + loss_pv * (1 - alpha) / 2
836 | batch_loss = loss_ent + loss_pk + loss_pv
837 | loss = batch_loss
838 | loss.backward()
839 | optimizer.step()
840 | scheduler.step()
841 |
842 | # Compute loss
843 | print("Monitor loss at epoch...")
844 | (
845 | avg_train_loss_ent,
846 | avg_train_loss_pk,
847 | avg_train_loss_pv,
848 | avg_train_loss,
849 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader(
850 | train_dataloader,
851 | model,
852 | added_ent_type_tokens,
853 | added_pk_tokens,
854 | loss_mode,
855 | device=device,
856 | )
857 | (
858 | avg_val_loss_ent,
859 | avg_val_loss_pk,
860 | avg_val_loss_pv,
861 | avg_val_loss,
862 | ) = Trainer_E_Pk_Pv.evaluate_full_dataloader(
863 | val_dataloader,
864 | model,
865 | added_ent_type_tokens,
866 | added_pk_tokens,
867 | loss_mode,
868 | device=device,
869 | )
870 |
871 | print(f"Epoch: {epoch + 1}")
872 | print(
873 | f"Train: loss_ent: {avg_train_loss_ent}, loss_pk: {avg_train_loss_pk}, loss_pv: {avg_train_loss_pv}. loss: {avg_train_loss}"
874 | )
875 | print(
876 | f"Val: loss_ent: {avg_val_loss_ent}, loss_pk: {avg_val_loss_pk}, loss_pv: {avg_val_loss_pv}. loss: {avg_val_loss}"
877 | )
878 | if report is not None:  # the reporter callback is optional
879 | report(epoch=epoch + 1, validation_loss=avg_val_loss)
880 | if log_wandb:
881 | wandb.log({"Epoch": epoch + 1})
882 | wandb.log(
883 | {
884 | "train_loss_ent": avg_train_loss_ent,
885 | "train_loss_pk": avg_train_loss_pk,
886 | "train_loss_pv": avg_train_loss_pv,
887 | "train_loss": avg_train_loss,
888 | }
889 | )
890 | wandb.log(
891 | {
892 | "val_loss_ent": avg_val_loss_ent,
893 | "val_loss_pk": avg_val_loss_pk,
894 | "val_loss_pv": avg_val_loss_pv,
895 | "val_loss": avg_val_loss,
896 | }
897 | )
898 |
899 | # Check for early stopping
900 | if avg_val_loss < min_val_loss:
901 | print(f"Save model... (epoch)")
902 | no_improve_epochs = 0
903 | min_val_loss = avg_val_loss
904 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
905 |
906 | if use_lora:
907 | model_path = f"{save_path}"
908 | model.t5_model.shared.modules_to_save["default"] = (
909 | model.t5_model.shared.original_module
910 | )
911 | model.save_pretrained(model_path)
912 | # model = model.merge_and_unload()
913 | else:
914 | model_path = f"{save_path}.pt"
915 | torch.save(model.state_dict(), model_path)
916 | else:
917 | no_improve_epochs += 1
918 |
919 | if no_improve_epochs == 10:
920 | print(f"Early stopping at epoch {epoch + 1}")
921 | break
922 | print("Time for epoch:", time.time() - start_time)
923 | # if log_wandb:
924 | # run.finish()
925 |
926 | def evaluate(self, model, test_dataloader, tokenizer, **kwargs):
927 | pass
928 |
929 | @staticmethod
930 | def generate_full_json_output(
931 | model,
932 | dataloader,
933 | added_ent_type_tokens,
934 | added_pk_tokens,
935 | tokenizer,
936 | device,
937 | mode,
938 | ):
939 |
940 | def extract_elements(given_list):
941 | extracted = []
942 | for item in given_list:
943 | elements = item.split("<extra_id_1>")  # entities are separated by the T5 sentinel token
944 | for element in elements:
945 | cleaned_element = element.strip()
946 | if cleaned_element and not cleaned_element.startswith("<"):
947 | extracted.append(cleaned_element)
948 | return extracted
949 |
950 | model.eval()
951 | results = []
952 |
953 | text_index = 0
954 |
955 | with torch.no_grad():
956 | for batch in dataloader:
957 | start_time = time.time()
958 | input_ids = batch["input_ids"].to(device) # (b, seq_len)
959 | # labels_ent = batch["labels_ent"].to(device) # (b, max_num_Entity * 2)
960 | # labels_ent_tokenized = batch["labels_ent_tokenized"].to(device) # (b, max_num_Entity * 3)
961 | real_labels_ent_name = batch[
962 | "real_labels_ent_name"
963 | ] # (b, max_num_Entity * 6)
964 | labels_ent_name = batch["labels_ent_name"].to(
965 | device
966 | ) # (b, max_num_Entity * 6)
967 | labels_pk = batch["labels_pk"].to(
968 | device
969 | ) # (b, max_num_Entity, num_all_pks+2)
970 | labels_pv = batch["labels_pv"].to(
971 | device
972 | ) # (b, max_num_Entity, num_all_pks, max_prop_len)
973 | attention_mask = batch["attention_mask"].to(device)
974 | # attention_mask_ent = batch["attention_mask_ent"].to(device)
975 | # attention_mask_ent_tokenized = batch["attention_mask_ent_tokenized"].to(device)
976 | attention_mask_ent_name = batch["attention_mask_ent_name"].to(device)
977 | attention_mask_pk = batch["attention_mask_pk"].to(device)
978 | attention_mask_pv = batch["attention_mask_pv"].to(device)
979 | max_len_pv = attention_mask_pv.shape[-1]
980 |
981 | predict_ent_ids, predict_pk_ids, predict_pv_ids = model(
982 | input_ids, # (b, seq_len)
983 | labels_ent_name, # (b, max_num_Entity * 6)
984 | real_labels_ent_name, # (b, max_num_Entity * 6)
985 | labels_pk, # (b, max_num_Entity, num_all_pks+2)
986 | labels_pv, # (b, max_num_Entity, num_all_pks, max_prop_len)
987 | attention_mask,
988 | attention_mask_ent_name,
989 | attention_mask_pk,
990 | attention_mask_pv,
991 | max_len_pv,
992 | device,
993 | added_ent_type_tokens,
994 | added_pk_tokens,
995 | mode="test",
996 | )
997 |
998 | # print("labels_pv:", labels_pv.shape)
999 | # print("predict_ent_ids:", predict_ent_ids.shape, predict_ent_ids)
1000 | # print("predict_pk_ids:", predict_pk_ids.shape, predict_pk_ids)
1001 | # print("predict_pv_ids:", predict_pv_ids.shape, predict_pv_ids)
1002 |
1003 | # predict_pk_ids = replace_with_closest_embedding(predict_pk_ids, added_ent_type_tokens, added_pk_tokens,
1004 | # model, device)
1005 |
1006 | # print("predict_pk_ids:", predict_pk_ids)
1007 |
1008 | """Format prediction"""
1009 | predict_ent_tokens = [
1010 | tokenizer.decode(ids, skip_special_tokens=False)
1011 | for ids in predict_ent_ids
1012 | ]
1013 | res_predict_ent_name = extract_elements(predict_ent_tokens)
1014 | # predict_pk_tokens = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predict_pk_ids]
1015 | # predict_pv_tokens = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predict_pv_ids]
1016 |
1017 | # Calculate the number of properties for each entity, subtracting 1 for the type
1018 | num_properties_per_entity = [
1019 | (pk_ids != 1).sum().item() - 1 for pk_ids in predict_pk_ids
1020 | ]
1021 |
1022 | # Decode predict_pk_ids
1023 | predict_pk_tokens = [
1024 | tokenizer.decode(pk_ids, skip_special_tokens=True)
1025 | for pk_ids in predict_pk_ids
1026 | ]
1027 |
1028 | # Initialize variables
1029 | predict_pv_tokens = []
1030 | start_index = 0
1031 |
1032 | # Process each entity's property values
1033 | for num_props in num_properties_per_entity:
1034 | # Extract the property values for this entity
1035 | entity_pv_ids = predict_pv_ids[
1036 | start_index : start_index + num_props
1037 | ]
1038 |
1039 | # Decode and handle value 1 as empty string
1040 | entity_pv_tokens = [
1041 | (
1042 | tokenizer.decode(pv_ids, skip_special_tokens=True)
1043 | if pv_ids[0] != 1
1044 | else ""
1045 | )
1046 | for pv_ids in entity_pv_ids
1047 | ]
1048 |
1049 | # Append to the main list
1050 | predict_pv_tokens.append(entity_pv_tokens)
1051 |
1052 | # Update the start index for the next entity
1053 | start_index += num_props
1054 |
1055 | """Format ground truth"""
1056 | res_real_ent_name = extract_elements(real_labels_ent_name)
1057 |
1058 | labels_pk = labels_pk[(labels_pk != 0).any(dim=-1)]
1059 | labels_pk_flat = labels_pk.view(
1060 | -1, labels_pk.size(-1)
1061 | ) # Flattening to 2D
1062 | res_real_pk = [
1063 | tokenizer.decode(ids, skip_special_tokens=True)
1064 | for ids in labels_pk_flat
1065 | ]
1066 |
1067 | res_real_pv = []
1068 | for block in labels_pv[0]: # Assuming the first dimension is always 1
1069 | # Filter out zero rows
1070 | filtered_block = block[(block != 0).any(dim=-1)]
1071 |
1072 | # Check if the filtered block is not empty
1073 | if filtered_block.size(0) != 0:
1074 | # Decode each row in the filtered block
1075 | decoded_rows = [
1076 | tokenizer.decode(row, skip_special_tokens=True)
1077 | for row in filtered_block
1078 | ]
1079 | res_real_pv.append(decoded_rows)
1080 |
1081 | # Append the results of this batch to the 'batches' key in the results dictionary
1082 | results.append(
1083 | {
1084 | "predict_ent": res_predict_ent_name,
1085 | "predict_pk": predict_pk_tokens,
1086 | "predict_pv": predict_pv_tokens,
1087 | }
1088 | )
1089 |
1090 | # Optional: Print the current batch results
1091 | print("Batch", text_index)
1092 | print("truth_ent:", res_real_ent_name)
1093 | print("truth_pk:", res_real_pk)
1094 | print("truth_pv:", res_real_pv)
1095 | print()
1096 |
1097 | print("predict_ent:", res_predict_ent_name)
1098 | print("predict_pk:", predict_pk_tokens)
1099 | print("predict_pv:", predict_pv_tokens)
1100 | print("---------------------")
1101 | text_index += 1
1102 |
1103 | return results
1104 |
--------------------------------------------------------------------------------
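
A wiring sketch for the classes above. The dataloaders and added-token lists are stand-ins: real code passes DataLoaders yielding the batch dicts consumed by compute_batch_loss, and the ids of the entity-type / property-key tokens added to the tokenizer:

import torch
from transformers import T5Tokenizer

from trainer.trainer_musee import Predictor_E_Pk_Pv, Trainer_E_Pk_Pv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = Predictor_E_Pk_Pv(
    pretrained_model_name="t5-small",
    max_seq_length=512,
    max_num_entity=8,
    tokenizer=tokenizer,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)

train_dataloader = val_dataloader = []        # placeholders; replace with real DataLoaders
added_ent_type_tokens = added_pk_tokens = []  # placeholders; replace with real token ids

if len(train_dataloader) > 0:  # guard: the placeholders above must be replaced first
    Trainer_E_Pk_Pv().train(
        "checkpoints/musee",  # save_path
        model,
        train_dataloader,
        val_dataloader,
        optimizer,
        scheduler,
        epochs=10,
        device=device,
        added_ent_type_tokens=added_ent_type_tokens,
        added_pk_tokens=added_pk_tokens,
        loss_mode="mean",
        use_lora=False,
    )
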
/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import torch
6 | import torch.backends.cudnn as cudnn
7 | import torch.nn as nn
8 | from peft import LoraConfig, PrefixTuningConfig, TaskType, get_peft_model
9 | from scipy.sparse import csr_matrix
10 | from scipy.sparse.csgraph import min_weight_full_bipartite_matching
11 | from torch.nn import MultiheadAttention
12 | from transformers import (AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer,
13 | DataCollatorWithPadding, GPT2LMHeadModel,
14 | GPT2Tokenizer, LlamaForCausalLM, LlamaTokenizer,
15 | OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
16 | OpenLlamaConfig, OpenLlamaForCausalLM,
17 | T5ForConditionalGeneration, T5Tokenizer, Trainer,
18 | TrainingArguments, TransfoXLLMHeadModel,
19 | TransfoXLTokenizer)
20 | from transformers.models.t5.modeling_t5 import T5Attention
21 |
22 |
23 | def set_seed(seed):
24 | torch.cuda.manual_seed(seed)
25 | torch.manual_seed(seed)
26 | np.random.seed(seed)
27 | random.seed(seed)
28 | cudnn.deterministic = True
29 | cudnn.benchmark = False
30 |
31 |
32 | def plot_property_stats(
33 | property_accuracies, property_counts, fig_size=(10, 5), path_name=None
34 | ):
35 | # Create figure and axis
36 | fig, ax1 = plt.subplots(figsize=fig_size)
37 |
38 | # Sort property_counts by value, high to low
39 | sorted_property_counts = {
40 | k: v
41 | for k, v in sorted(
42 | property_counts.items(), key=lambda item: item[1], reverse=True
43 | )
44 | }
45 |
46 | # Bar plot with property counts
47 | ax1.bar(
48 | sorted_property_counts.keys(),
49 | sorted_property_counts.values(),
50 | color="b",
51 | alpha=0.5,
52 | )
53 | ax1.set_xlabel("Property Name")
54 | ax1.set_ylabel("Counts", color="b")
55 | ax1.tick_params(axis="y", labelcolor="b")
56 |
57 | # Rotate x-labels 90 degrees
58 | plt.xticks(rotation=90)
59 |
60 | # Create a second y-axis that shares the same x-axis, we already handled the x-label with ax1
61 | ax2 = ax1.twinx()
62 |
63 | # Line plot with property accuracies
64 |     # Keep sorted_property_accuracies in the same key order as sorted_property_counts
65 | sorted_property_accuracies = {
66 | k: property_accuracies[k] for k in sorted_property_counts.keys()
67 | }
68 | ax2.plot(
69 | sorted_property_accuracies.keys(),
70 | sorted_property_accuracies.values(),
71 | color="r",
72 | )
73 |     ax2.set_ylabel("Accuracy", color="r")
74 | ax2.tick_params(axis="y", labelcolor="r")
75 |
76 | # Layout
77 | fig.tight_layout()
78 |
79 | # Save the figure if a path is provided
80 | if path_name is not None:
81 | plt.savefig(path_name, dpi=300, bbox_inches="tight")
82 |
83 | plt.show()
84 |
85 |
86 | def get_generative_model_and_tokenizer(config):
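87 |     """Build (model, tokenizer) for the backbone named by config.generative_model.
88 |
89 |     Supports gpt2/gpt2-large, a from-scratch OpenLlama ("custom"), llama_3B,
90 |     and flan-t5-*/t5-* variants. Honors config.saved_model_path and an
91 |     optional config.torch_dtype, and wraps the model with LoRA adapters when
92 |     config.use_lora is set.
93 |     """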
87 | if config.saved_model_path:
88 | print("Loading pretrained model at", config.saved_model_path)
89 |
90 | if config.generative_model in ("gpt2", "gpt2-large"):
91 | model_path = config.saved_model_path or config.generative_model
92 | kwargs = {
93 | "pretrained_model_name_or_path": model_path,
94 | # "device_map": 'auto',
95 | }
96 | if hasattr(config, "torch_dtype"):
97 | if config.torch_dtype == "float16":
98 | kwargs["torch_dtype"] = torch.float16
99 | elif config.torch_dtype != "float32":
100 | raise ValueError(
101 | f"torch_dtype: {config.torch_dtype} not recognized in config file."
102 | )
103 | tokenizer = GPT2Tokenizer.from_pretrained(config.generative_model)
104 | model = GPT2LMHeadModel.from_pretrained(**kwargs)
105 | tokenizer.pad_token = tokenizer.eos_token
106 |     elif config.generative_model == "custom":
107 |         model_config = OpenLlamaConfig(  # separate name: don't shadow `config`, used below
108 |             vocab_size=32000,
109 |             hidden_size=config.hidden_size,
110 |             intermediate_size=config.intermediate_size,
111 |             num_hidden_layers=config.num_hidden_layers,
112 |             max_position_embeddings=config.max_position_embeddings,
113 |         )
114 |         model = OpenLlamaForCausalLM(config=model_config)
115 |         tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2")
116 |         tokenizer.pad_token_id = 0
117 | elif config.generative_model == "llama_3B":
118 | model_path = config.saved_model_path or "openlm-research/open_llama_3b_v2"
119 | tokenizer = AutoTokenizer.from_pretrained(model_path)
120 | kwargs = {
121 | "pretrained_model_name_or_path": model_path,
122 | # "device_map": 'auto',
123 | }
124 | if hasattr(config, "torch_dtype"):
125 | if config.torch_dtype == "float16":
126 | kwargs["torch_dtype"] = torch.float16
127 | elif config.torch_dtype != "float32":
128 | raise ValueError(
129 | f"torch_dtype: {config.torch_dtype} not recognized in config file."
130 | )
131 | model = LlamaForCausalLM.from_pretrained(**kwargs)
132 | tokenizer.pad_token = tokenizer.eos_token
133 | # model.tie_weights()
134 |     elif "flan-t5" in config.generative_model:
135 |         model_path = config.saved_model_path or "google/" + config.generative_model
136 |         kwargs = {
137 |             "pretrained_model_name_or_path": model_path,
138 |             # "device_map": 'auto',
139 |         }
140 |         if hasattr(config, "torch_dtype"):
141 |             if config.torch_dtype == "float16":
142 |                 kwargs["torch_dtype"] = torch.float16
143 |             elif config.torch_dtype != "float32":
144 |                 raise ValueError(
145 |                     f"torch_dtype: {config.torch_dtype} not recognized in config file."
146 |                 )
147 |         model = AutoModelForSeq2SeqLM.from_pretrained(**kwargs)
148 |         tokenizer = AutoTokenizer.from_pretrained(model_path)
149 |     elif config.generative_model in ["t5-small", "t5-base", "t5-large"]:
150 |         model_path = config.saved_model_path or config.generative_model
151 |         kwargs = {
152 |             "pretrained_model_name_or_path": model_path,
153 |         }
154 |         if hasattr(config, "torch_dtype"):
155 |             if config.torch_dtype == "float16":
156 |                 kwargs["torch_dtype"] = torch.float16
157 |             elif config.torch_dtype != "float32":
158 |                 raise ValueError(
159 |                     f"torch_dtype: {config.torch_dtype} not recognized in config file."
160 |                 )
161 |         model = T5ForConditionalGeneration.from_pretrained(**kwargs)
162 |         tokenizer = T5Tokenizer.from_pretrained(config.generative_model)
163 |         # Add the new tokens to the tokenizer. Note: once tokens are added,
164 |         # the model embeddings must be resized to match, e.g. via
165 |         # model.resize_token_embeddings(len(tokenizer)), if the caller does
166 |         # not already do so.
167 |         new_tokens = ["", "", "", "{", "}"]
168 |         tokenizer.add_tokens(new_tokens)
169 | else:
170 | raise ValueError(f'model name "{config.generative_model}" not recognized')
171 |
172 |     # Apply LoRA fine-tuning if config.use_lora is True
173 | if config.use_lora:
174 | lora_config = LoraConfig(
175 | r=config.lora_r,
176 | lora_alpha=config.lora_alpha,
177 | lora_dropout=config.lora_dropout,
178 |             task_type=(
179 |                 TaskType.SEQ_2_SEQ_LM
180 |                 if "t5" in config.generative_model  # t5-* and flan-t5-* are seq2seq
181 |                 else TaskType.CAUSAL_LM
182 |             ),
183 | target_modules=config.lora_target_modules,
184 | modules_to_save=config.lora_modules_to_save,
185 | )
186 | model = get_peft_model(model, lora_config)
187 | model.print_trainable_parameters()
188 |
189 | return model, tokenizer
190 |
191 |
192 | def compute_inverse_frequency_weights(entity_type_counts, num_entity_types):
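193 |     """Return class weights proportional to inverse frequency, normalized to sum to 1.
194 |
195 |     Worked example (hypothetical counts): {"PER": 4, "LOC": 1} gives inverse
196 |     frequencies [0.25, 1.0] and normalized weights [0.2, 0.8].
197 |     `num_entity_types` is currently unused.
198 |     """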
193 | # Extract counts
194 | counts = list(entity_type_counts.values())
195 | # Compute inverse frequency
196 | inverse_freq = [1.0 / count for count in counts]
197 |     # Normalize so the weights sum to 1 (optional, but keeps them on a comparable scale across datasets)
198 | total = sum(inverse_freq)
199 | normalized_weights = [freq / total for freq in inverse_freq]
200 |
201 | return torch.tensor(normalized_weights, dtype=torch.float32)
202 |
203 |
204 | def print_trainable_parameters(model):
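205 |     """Print the number of trainable parameters in `model`, both as a raw
206 |     count and as a percentage of all parameters."""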
205 | trainable_params = 0
206 | all_param = 0
207 | for _, param in model.named_parameters():
208 | all_param += param.numel()
209 | if param.requires_grad:
210 | trainable_params += param.numel()
211 | print(
212 | f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
213 | )
214 |
215 |
216 | def get_attention_paths(model, path=""):
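217 |     """Recursively collect dotted paths to the q and v projections of every
218 |     T5Attention module, e.g. for use as LoRA target_modules."""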
217 | paths = []
218 | for name, module in model.named_children():
219 | new_path = f"{path}.{name}" if path else name
220 |
221 |         if isinstance(module, T5Attention):
222 | paths.append(f"{new_path}.q")
223 | # paths.append(f"{new_path}.k")
224 | paths.append(f"{new_path}.v")
225 | # paths.append(f"{new_path}.o")
226 | else:
227 | paths.extend(get_attention_paths(module, new_path))
228 |
229 | return paths
230 |
231 |
232 | def get_transformerlayer_paths(model, path=""):
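233 |     """Recursively collect the dotted paths of every nn.TransformerEncoderLayer in `model`."""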
233 | paths = []
234 | for name, module in model.named_children():
235 | new_path = f"{path}.{name}" if path else name
236 |
237 | if isinstance(module, nn.TransformerEncoderLayer):
238 | paths.append(new_path)
239 | else:
240 | paths.extend(get_transformerlayer_paths(module, new_path))
241 |
242 | return paths
243 |
244 |
245 | def remove_duplicates_and_postprocess(entity_lst):
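246 |     """Coerce every entity value to a string, then drop exact duplicates
247 |     while preserving first-seen order."""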
246 | def postprocess(entity):
247 | for key, value in entity.items():
248 | if not isinstance(value, str):
249 | if isinstance(value, list):
250 | try:
251 | value = " ".join(value)
252 |                 except TypeError:  # list items that are not strings
253 | value = str(value)
254 | else:
255 | value = str(value)
256 | entity[key] = value
257 | return entity
258 |
259 | new_lst = []
260 | for e in entity_lst:
261 | e = postprocess(e)
262 |         if e not in new_lst:
263 | new_lst.append(e)
264 | return new_lst
265 |
--------------------------------------------------------------------------------